diff options
author | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
---|---|---|
committer | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
commit | f154da9e12608589e8d5f0508f908a0c3e88a1bb (patch) | |
tree | f8255d51e10c6f1e0ed69702200b966c9556a431 /src/net/http | |
parent | 8d8329ed5dfb9622c82a9fbec6fd99a580f9c9f6 (diff) | |
download | golang-upstream/1.4.tar.gz |
Imported Upstream version 1.4upstream/1.4
Diffstat (limited to 'src/net/http')
70 files changed, 26148 insertions, 0 deletions
diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go new file mode 100644 index 000000000..45fc2e57c --- /dev/null +++ b/src/net/http/cgi/child.go @@ -0,0 +1,206 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements CGI from the perspective of a child +// process. + +package cgi + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" +) + +// Request returns the HTTP request as represented in the current +// environment. This assumes the current program is being run +// by a web server in a CGI environment. +// The returned Request's Body is populated, if applicable. +func Request() (*http.Request, error) { + r, err := RequestFromMap(envMap(os.Environ())) + if err != nil { + return nil, err + } + if r.ContentLength > 0 { + r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + } + return r, nil +} + +func envMap(env []string) map[string]string { + m := make(map[string]string) + for _, kv := range env { + if idx := strings.Index(kv, "="); idx != -1 { + m[kv[:idx]] = kv[idx+1:] + } + } + return m +} + +// RequestFromMap creates an http.Request from CGI variables. +// The returned Request's Body field is not populated. +func RequestFromMap(params map[string]string) (*http.Request, error) { + r := new(http.Request) + r.Method = params["REQUEST_METHOD"] + if r.Method == "" { + return nil, errors.New("cgi: no REQUEST_METHOD in environment") + } + + r.Proto = params["SERVER_PROTOCOL"] + var ok bool + r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto) + if !ok { + return nil, errors.New("cgi: invalid SERVER_PROTOCOL version") + } + + r.Close = true + r.Trailer = http.Header{} + r.Header = http.Header{} + + r.Host = params["HTTP_HOST"] + + if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { + clen, err := strconv.ParseInt(lenstr, 10, 64) + if err != nil { + return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr) + } + r.ContentLength = clen + } + + if ct := params["CONTENT_TYPE"]; ct != "" { + r.Header.Set("Content-Type", ct) + } + + // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers + for k, v := range params { + if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" { + continue + } + r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v) + } + + // TODO: cookies. parsing them isn't exported, though. + + uriStr := params["REQUEST_URI"] + if uriStr == "" { + // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING. + uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"] + s := params["QUERY_STRING"] + if s != "" { + uriStr += "?" + s + } + } + + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + + if r.Host != "" { + // Hostname is provided, so we can reasonably construct a URL. + rawurl := r.Host + uriStr + if r.TLS == nil { + rawurl = "http://" + rawurl + } else { + rawurl = "https://" + rawurl + } + url, err := url.Parse(rawurl) + if err != nil { + return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) + } + r.URL = url + } + // Fallback logic if we don't have a Host header or the URL + // failed to parse + if r.URL == nil { + url, err := url.Parse(uriStr) + if err != nil { + return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr) + } + r.URL = url + } + + // Request.RemoteAddr has its port set by Go's standard http + // server, so we do here too. We don't have one, though, so we + // use a dummy one. + r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0") + + return r, nil +} + +// Serve executes the provided Handler on the currently active CGI +// request, if any. If there's no current CGI environment +// an error is returned. The provided handler may be nil to use +// http.DefaultServeMux. +func Serve(handler http.Handler) error { + req, err := Request() + if err != nil { + return err + } + if handler == nil { + handler = http.DefaultServeMux + } + rw := &response{ + req: req, + header: make(http.Header), + bufw: bufio.NewWriter(os.Stdout), + } + handler.ServeHTTP(rw, req) + rw.Write(nil) // make sure a response is sent + if err = rw.bufw.Flush(); err != nil { + return err + } + return nil +} + +type response struct { + req *http.Request + header http.Header + bufw *bufio.Writer + headerSent bool +} + +func (r *response) Flush() { + r.bufw.Flush() +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(p []byte) (n int, err error) { + if !r.headerSent { + r.WriteHeader(http.StatusOK) + } + return r.bufw.Write(p) +} + +func (r *response) WriteHeader(code int) { + if r.headerSent { + // Note: explicitly using Stderr, as Stdout is our HTTP output. + fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL) + return + } + r.headerSent = true + fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code)) + + // Set a default Content-Type + if _, hasType := r.header["Content-Type"]; !hasType { + r.header.Add("Content-Type", "text/html; charset=utf-8") + } + + r.header.Write(r.bufw) + r.bufw.WriteString("\r\n") + r.bufw.Flush() +} diff --git a/src/net/http/cgi/child_test.go b/src/net/http/cgi/child_test.go new file mode 100644 index 000000000..075d8411b --- /dev/null +++ b/src/net/http/cgi/child_test.go @@ -0,0 +1,131 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for CGI (the child process perspective) + +package cgi + +import ( + "testing" +) + +func TestRequest(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "HTTP_USER_AGENT": "goclient", + "HTTP_FOO_BAR": "baz", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + "CONTENT_TYPE": "text/xml", + "REMOTE_ADDR": "5.6.7.8", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if g, e := req.UserAgent(), "goclient"; e != g { + t.Errorf("expected UserAgent %q; got %q", e, g) + } + if g, e := req.Method, "GET"; e != g { + t.Errorf("expected Method %q; got %q", e, g) + } + if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { + t.Errorf("expected Content-Type %q; got %q", e, g) + } + if g, e := req.ContentLength, int64(123); e != g { + t.Errorf("expected ContentLength %d; got %d", e, g) + } + if g, e := req.Referer(), "elsewhere"; e != g { + t.Errorf("expected Referer %q; got %q", e, g) + } + if req.Header == nil { + t.Fatalf("unexpected nil Header") + } + if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g { + t.Errorf("expected Foo-Bar %q; got %q", e, g) + } + if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if g, e := req.FormValue("a"), "b"; e != g { + t.Errorf("expected FormValue(a) %q; got %q", e, g) + } + if req.Trailer == nil { + t.Errorf("unexpected nil Trailer") + } + if req.TLS != nil { + t.Errorf("expected nil TLS") + } + if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } +} + +func TestRequestWithTLS(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "REQUEST_URI": "/path?a=b", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if g, e := req.URL.String(), "https://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } +} + +func TestRequestWithoutHost(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "/path?a=b"; e != g { + t.Errorf("URL = %q; want %q", g, e) + } +} + +func TestRequestWithoutRequestURI(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "example.com", + "REQUEST_METHOD": "GET", + "SCRIPT_NAME": "/dir/scriptname", + "PATH_INFO": "/p1/p2", + "QUERY_STRING": "a=1&b=2", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g { + t.Errorf("URL = %q; want %q", g, e) + } +} diff --git a/src/net/http/cgi/host.go b/src/net/http/cgi/host.go new file mode 100644 index 000000000..ec95a972c --- /dev/null +++ b/src/net/http/cgi/host.go @@ -0,0 +1,377 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements the host side of CGI (being the webserver +// parent process). + +// Package cgi implements CGI (Common Gateway Interface) as specified +// in RFC 3875. +// +// Note that using CGI means starting a new process to handle each +// request, which is typically less efficient than using a +// long-running server. This package is intended primarily for +// compatibility with existing systems. +package cgi + +import ( + "bufio" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" +) + +var trailingPort = regexp.MustCompile(`:([0-9]+)$`) + +var osDefaultInheritEnv = map[string][]string{ + "darwin": {"DYLD_LIBRARY_PATH"}, + "freebsd": {"LD_LIBRARY_PATH"}, + "hpux": {"LD_LIBRARY_PATH", "SHLIB_PATH"}, + "irix": {"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}, + "linux": {"LD_LIBRARY_PATH"}, + "openbsd": {"LD_LIBRARY_PATH"}, + "solaris": {"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}, + "windows": {"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, +} + +// Handler runs an executable in a subprocess with a CGI environment. +type Handler struct { + Path string // path to the CGI executable + Root string // root URI prefix of handler or empty for "/" + + // Dir specifies the CGI executable's working directory. + // If Dir is empty, the base directory of Path is used. + // If Path has no base directory, the current working + // directory is used. + Dir string + + Env []string // extra environment variables to set, if any, as "key=value" + InheritEnv []string // environment variables to inherit from host, as "key" + Logger *log.Logger // optional log for errors or nil to use log.Print + Args []string // optional arguments to pass to child process + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 ยง 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler +} + +// removeLeadingDuplicates remove leading duplicate in environments. +// It's possible to override environment like following. +// cgi.Handler{ +// ... +// Env: []string{"SCRIPT_FILENAME=foo.php"}, +// } +func removeLeadingDuplicates(env []string) (ret []string) { + n := len(env) + for i := 0; i < n; i++ { + e := env[i] + s := strings.SplitN(e, "=", 2)[0] + found := false + for j := i + 1; j < n; j++ { + if s == strings.SplitN(env[j], "=", 2)[0] { + found = true + break + } + } + if !found { + ret = append(ret, e) + } + } + return +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + root := h.Root + if root == "" { + root = "/" + } + + if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("Chunked request bodies are not supported by CGI.")) + return + } + + pathInfo := req.URL.Path + if root != "/" && strings.HasPrefix(pathInfo, root) { + pathInfo = pathInfo[len(root):] + } + + port := "80" + if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 { + port = matches[1] + } + + env := []string{ + "SERVER_SOFTWARE=go", + "SERVER_NAME=" + req.Host, + "SERVER_PROTOCOL=HTTP/1.1", + "HTTP_HOST=" + req.Host, + "GATEWAY_INTERFACE=CGI/1.1", + "REQUEST_METHOD=" + req.Method, + "QUERY_STRING=" + req.URL.RawQuery, + "REQUEST_URI=" + req.URL.RequestURI(), + "PATH_INFO=" + pathInfo, + "SCRIPT_NAME=" + root, + "SCRIPT_FILENAME=" + h.Path, + "REMOTE_ADDR=" + req.RemoteAddr, + "REMOTE_HOST=" + req.RemoteAddr, + "SERVER_PORT=" + port, + } + + if req.TLS != nil { + env = append(env, "HTTPS=on") + } + + for k, v := range req.Header { + k = strings.Map(upperCaseAndUnderscore, k) + joinStr := ", " + if k == "COOKIE" { + joinStr = "; " + } + env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr)) + } + + if req.ContentLength > 0 { + env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength)) + } + if ctype := req.Header.Get("Content-Type"); ctype != "" { + env = append(env, "CONTENT_TYPE="+ctype) + } + + if h.Env != nil { + env = append(env, h.Env...) + } + + envPath := os.Getenv("PATH") + if envPath == "" { + envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" + } + env = append(env, "PATH="+envPath) + + for _, e := range h.InheritEnv { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + for _, e := range osDefaultInheritEnv[runtime.GOOS] { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + env = removeLeadingDuplicates(env) + + var cwd, path string + if h.Dir != "" { + path = h.Path + cwd = h.Dir + } else { + cwd, path = filepath.Split(h.Path) + } + if cwd == "" { + cwd = "." + } + + internalError := func(err error) { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("CGI error: %v", err) + } + + cmd := &exec.Cmd{ + Path: path, + Args: append([]string{h.Path}, h.Args...), + Dir: cwd, + Env: env, + Stderr: os.Stderr, // for now + } + if req.ContentLength != 0 { + cmd.Stdin = req.Body + } + stdoutRead, err := cmd.StdoutPipe() + if err != nil { + internalError(err) + return + } + + err = cmd.Start() + if err != nil { + internalError(err) + return + } + if hook := testHookStartProcess; hook != nil { + hook(cmd.Process) + } + defer cmd.Wait() + defer stdoutRead.Close() + + linebody := bufio.NewReaderSize(stdoutRead, 1024) + headers := make(http.Header) + statusCode := 0 + headerLines := 0 + sawBlankLine := false + for { + line, isPrefix, err := linebody.ReadLine() + if isPrefix { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: long header line from subprocess.") + return + } + if err == io.EOF { + break + } + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error reading headers: %v", err) + return + } + if len(line) == 0 { + sawBlankLine = true + break + } + headerLines++ + parts := strings.SplitN(string(line), ":", 2) + if len(parts) < 2 { + h.printf("cgi: bogus header line: %s", string(line)) + continue + } + header, val := parts[0], parts[1] + header = strings.TrimSpace(header) + val = strings.TrimSpace(val) + switch { + case header == "Status": + if len(val) < 3 { + h.printf("cgi: bogus status (short): %q", val) + return + } + code, err := strconv.Atoi(val[0:3]) + if err != nil { + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) + return + } + statusCode = code + default: + headers.Add(header, val) + } + } + if headerLines == 0 || !sawBlankLine { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: no headers") + return + } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 && headers.Get("Content-Type") == "" { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: missing required Content-Type in headers") + return + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + + rw.WriteHeader(statusCode) + + _, err = io.Copy(rw, linebody) + if err != nil { + h.printf("cgi: copy error: %v", err) + // And kill the child CGI process so we don't hang on + // the deferred cmd.Wait above if the error was just + // the client (rw) going away. If it was a read error + // (because the child died itself), then the extra + // kill of an already-dead process is harmless (the PID + // won't be reused until the Wait above). + cmd.Process.Kill() + } +} + +func (h *Handler) printf(format string, v ...interface{}) { + if h.Logger != nil { + h.Logger.Printf(format, v...) + } else { + log.Printf(format, v...) + } +} + +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.Parse(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + +func upperCaseAndUnderscore(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r - ('a' - 'A') + case r == '-': + return '_' + case r == '=': + // Maybe not part of the CGI 'spec' but would mess up + // the environment in any case, as Go represents the + // environment as a slice of "key=value" strings. + return '_' + } + // TODO: other transformations in spec or practice? + return r +} + +var testHookStartProcess func(*os.Process) // nil except for some tests diff --git a/src/net/http/cgi/host_test.go b/src/net/http/cgi/host_test.go new file mode 100644 index 000000000..8c16e6897 --- /dev/null +++ b/src/net/http/cgi/host_test.go @@ -0,0 +1,461 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for package cgi + +package cgi + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +func newRequest(httpreq string) *http.Request { + buf := bufio.NewReader(strings.NewReader(httpreq)) + req, err := http.ReadRequest(buf) + if err != nil { + panic("cgi: bogus http request in test: " + httpreq) + } + req.RemoteAddr = "1.2.3.4" + return req +} + +func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder { + rw := httptest.NewRecorder() + req := newRequest(httpreq) + h.ServeHTTP(rw, req) + + // Make a map to hold the test map that the CGI returns. + m := make(map[string]string) + m["_body"] = rw.Body.String() + linesRead := 0 +readlines: + for { + line, err := rw.Body.ReadString('\n') + switch { + case err == io.EOF: + break readlines + case err != nil: + t.Fatalf("unexpected error reading from CGI: %v", err) + } + linesRead++ + trimmedLine := strings.TrimRight(line, "\r\n") + split := strings.SplitN(trimmedLine, "=", 2) + if len(split) != 2 { + t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", + len(split), linesRead, line, m) + } + m[split[0]] = split[1] + } + + for key, expected := range expectedMap { + got := m[key] + if key == "cwd" { + // For Windows. golang.org/issue/4645. + fi1, _ := os.Stat(got) + fi2, _ := os.Stat(expected) + if os.SameFile(fi1, fi2) { + got = expected + } + } + if got != expected { + t.Errorf("for key %q got %q; expected %q", key, got, expected) + } + } + return rw +} + +var cgiTested, cgiWorks bool + +func check(t *testing.T) { + if !cgiTested { + cgiTested = true + cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil + } + if !cgiWorks { + // No Perl on Windows, needed by test.cgi + // TODO: make the child process be Go, not Perl. + t.Skip("Skipping test: test.cgi failed.") + } +} + +func TestCGIBasicGet(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +func TestCGIBasicGetAbsPath(t *testing.T) { + check(t) + pwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd error: %v", err) + } + h := &Handler{ + Path: pwd + "/testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfo(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "param-a": "b", + "env-PATH_INFO": "/extrapath", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/test.cgi/extrapath?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfoDirRoot(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/myscript/", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/myscript/", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestDupHeaders(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-HTTP_COOKIE": "nom=NOM; yum=YUM", + "env-HTTP_X_FOO": "val1, val2", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ + "Cookie: nom=NOM\n"+ + "Cookie: yum=YUM\n"+ + "X-Foo: val1\n"+ + "X-Foo: val2\n"+ + "Host: example.com\n\n", + expectedMap) +} + +func TestPathInfoNoRoot(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "/bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/", + } + runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestCGIBasicPost(t *testing.T) { + check(t) + postReq := `POST /test.cgi?a=b HTTP/1.0 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Content-Length: 15 + +postfoo=postbar` + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-postfoo": "postbar", + "env-REQUEST_METHOD": "POST", + "env-CONTENT_LENGTH": "15", + "env-REQUEST_URI": "/test.cgi?a=b", + } + runCgiTest(t, h, postReq, expectedMap) +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +// The CGI spec doesn't allow chunked requests. +func TestCGIPostChunked(t *testing.T) { + check(t) + postReq := `POST /test.cgi?a=b HTTP/1.1 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Transfer-Encoding: chunked + +` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("") + + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{} + resp := runCgiTest(t, h, postReq, expectedMap) + if got, expected := resp.Code, http.StatusBadRequest; got != expected { + t.Fatalf("Expected %v response code from chunked request body; got %d", + expected, got) + } +} + +func TestRedirect(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + check(t) + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +// TestCopyError tests that we kill the process if there's an error copying +// its output. (for example, from the client having gone away) +func TestCopyError(t *testing.T) { + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + ts := httptest.NewServer(h) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil) + err = req.Write(conn) + if err != nil { + t.Fatalf("Write: %v", err) + } + + res, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("ReadResponse: %v", err) + } + + pidstr := res.Header.Get("X-CGI-Pid") + if pidstr == "" { + t.Fatalf("expected an X-CGI-Pid header in response") + } + pid, err := strconv.Atoi(pidstr) + if err != nil { + t.Fatalf("invalid X-CGI-Pid value") + } + + var buf [5000]byte + n, err := io.ReadFull(res.Body, buf[:]) + if err != nil { + t.Fatalf("ReadFull: %d bytes, %v", n, err) + } + + childRunning := func() bool { + return isProcessRunning(t, pid) + } + + if !childRunning() { + t.Fatalf("pre-conn.Close, expected child to be running") + } + conn.Close() + + tries := 0 + for tries < 25 && childRunning() { + time.Sleep(50 * time.Millisecond * time.Duration(tries)) + tries++ + } + if childRunning() { + t.Fatalf("post-conn.Close, expected child to be gone") + } +} + +func TestDirUnix(t *testing.T) { + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) + } + cwd, _ := os.Getwd() + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + Dir: cwd, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + cwd, _ = os.Getwd() + cwd = filepath.Join(cwd, "testdata") + h = &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestDirWindows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Skipping windows specific test.") + } + + cgifile, _ := filepath.Abs("testdata/test.cgi") + + var perl string + var err error + perl, err = exec.LookPath("perl") + if err != nil { + t.Skip("Skipping test: perl not found.") + } + perl, _ = filepath.Abs(perl) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + // If not specify Dir on windows, working directory should be + // base directory of perl. + cwd, _ = filepath.Split(perl) + if cwd != "" && cwd[len(cwd)-1] == filepath.Separator { + cwd = cwd[:len(cwd)-1] + } + h = &Handler{ + Path: perl, + Root: "/test.cgi", + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestEnvOverride(t *testing.T) { + cgifile, _ := filepath.Abs("testdata/test.cgi") + + var perl string + var err error + perl, err = exec.LookPath("perl") + if err != nil { + t.Skipf("Skipping test: perl not found.") + } + perl, _ = filepath.Abs(perl) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{ + "SCRIPT_FILENAME=" + cgifile, + "REQUEST_URI=/foo/bar"}, + } + expectedMap := map[string]string{ + "cwd": cwd, + "env-SCRIPT_FILENAME": cgifile, + "env-REQUEST_URI": "/foo/bar", + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/src/net/http/cgi/matryoshka_test.go b/src/net/http/cgi/matryoshka_test.go new file mode 100644 index 000000000..18c4803e7 --- /dev/null +++ b/src/net/http/cgi/matryoshka_test.go @@ -0,0 +1,228 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests a Go CGI program running under a Go CGI host process. +// Further, the two programs are the same binary, just checking +// their environment to figure out what mode to run in. + +package cgi + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "runtime" + "testing" + "time" +) + +// This test is a CGI host (testing host.go) that runs its own binary +// as a child process testing the other half of CGI (child.go). +func TestHostingOurselves(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "test": "Hello CGI-in-CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.go?foo=bar&a=b", + "env-SCRIPT_FILENAME": os.Args[0], + "env-SCRIPT_NAME": "/test.go", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +type customWriterRecorder struct { + w io.Writer + *httptest.ResponseRecorder +} + +func (r *customWriterRecorder) Write(p []byte) (n int, err error) { + return r.w.Write(p) +} + +type limitWriter struct { + w io.Writer + n int +} + +func (w *limitWriter) Write(p []byte) (n int, err error) { + if len(p) > w.n { + p = p[:w.n] + } + if len(p) > 0 { + n, err = w.w.Write(p) + w.n -= n + } + if w.n == 0 { + err = errors.New("past write limit") + } + return +} + +// If there's an error copying the child's output to the parent, test +// that we kill the child. +func TestKillChildAfterCopyError(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + + defer func() { testHookStartProcess = nil }() + proc := make(chan *os.Process, 1) + testHookStartProcess = func(p *os.Process) { + proc <- p + } + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil) + rec := httptest.NewRecorder() + var out bytes.Buffer + const writeLen = 50 << 10 + rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec} + + donec := make(chan bool, 1) + go func() { + h.ServeHTTP(rw, req) + donec <- true + }() + + select { + case <-donec: + if out.Len() != writeLen || out.Bytes()[0] != 'a' { + t.Errorf("unexpected output: %q", out.Bytes()) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout. ServeHTTP hung and didn't kill the child process?") + select { + case p := <-proc: + p.Kill() + t.Logf("killed process") + default: + t.Logf("didn't kill process") + } + } +} + +// Test that a child handler writing only headers works. +// golang.org/issue/7196 +func TestChildOnlyHeaders(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl") + } + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "_body": "", + } + replay := runCgiTest(t, h, "GET /test.go?no-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap) + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +// golang.org/issue/7198 +func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") } +func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") } +func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") } + +func want500Test(t *testing.T, path string) { + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "_body": "", + } + replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap) + if replay.Code != 500 { + t.Errorf("Got code %d; want 500", replay.Code) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// Note: not actually a test. +func TestBeChildCGIProcess(t *testing.T) { + if os.Getenv("REQUEST_METHOD") == "" { + // Not in a CGI environment; skipping test. + return + } + switch os.Getenv("REQUEST_URI") { + case "/immediate-disconnect": + os.Exit(0) + case "/no-content-type": + fmt.Printf("Content-Length: 6\n\nHello\n") + os.Exit(0) + case "/empty-headers": + fmt.Printf("\nHello") + os.Exit(0) + } + Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Test-Header", "X-Test-Value") + req.ParseForm() + if req.FormValue("no-body") == "1" { + return + } + if req.FormValue("write-forever") == "1" { + io.Copy(rw, neverEnding('a')) + for { + time.Sleep(5 * time.Second) // hang forever, until killed + } + } + fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") + for k, vv := range req.Form { + for _, v := range vv { + fmt.Fprintf(rw, "param-%s=%s\n", k, v) + } + } + for _, kv := range os.Environ() { + fmt.Fprintf(rw, "env-%s\n", kv) + } + })) + os.Exit(0) +} diff --git a/src/net/http/cgi/plan9_test.go b/src/net/http/cgi/plan9_test.go new file mode 100644 index 000000000..c8235831b --- /dev/null +++ b/src/net/http/cgi/plan9_test.go @@ -0,0 +1,18 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build plan9 + +package cgi + +import ( + "os" + "strconv" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + _, err := os.Stat("/proc/" + strconv.Itoa(pid)) + return err == nil +} diff --git a/src/net/http/cgi/posix_test.go b/src/net/http/cgi/posix_test.go new file mode 100644 index 000000000..5ff9e7d5e --- /dev/null +++ b/src/net/http/cgi/posix_test.go @@ -0,0 +1,21 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package cgi + +import ( + "os" + "syscall" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/src/net/http/cgi/testdata/test.cgi b/src/net/http/cgi/testdata/test.cgi new file mode 100755 index 000000000..3214df6f0 --- /dev/null +++ b/src/net/http/cgi/testdata/test.cgi @@ -0,0 +1,91 @@ +#!/usr/bin/perl +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +# +# Test script run as a child process under cgi_test.go + +use strict; +use Cwd; + +binmode STDOUT; + +my $q = MiniCGI->new; +my $params = $q->Vars; + +if ($params->{"loc"}) { + print "Location: $params->{loc}\r\n\r\n"; + exit(0); +} + +print "Content-Type: text/html\r\n"; +print "X-CGI-Pid: $$\r\n"; +print "X-Test-Header: X-Test-Value\r\n"; +print "\r\n"; + +if ($params->{"bigresponse"}) { + # 17 MB, for OS X: golang.org/issue/4958 + for (1..(17 * 1024)) { + print "A" x 1024, "\r\n"; + } + exit 0; +} + +print "test=Hello CGI\r\n"; + +foreach my $k (sort keys %$params) { + print "param-$k=$params->{$k}\r\n"; +} + +foreach my $k (sort keys %ENV) { + my $clean_env = $ENV{$k}; + $clean_env =~ s/[\n\r]//g; + print "env-$k=$clean_env\r\n"; +} + +# NOTE: msys perl returns /c/go/src/... not C:\go\.... +my $dir = getcwd(); +if ($^O eq 'MSWin32' || $^O eq 'msys') { + if ($dir =~ /^.:/) { + $dir =~ s!/!\\!g; + } else { + my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; + $cmd =~ s!\\!/!g; + $dir = `$cmd /c cd`; + chomp $dir; + } +} +print "cwd=$dir\r\n"; + +# A minimal version of CGI.pm, for people without the perl-modules +# package installed. (CGI.pm used to be part of the Perl core, but +# some distros now bundle perl-base and perl-modules separately...) +package MiniCGI; + +sub new { + my $class = shift; + return bless {}, $class; +} + +sub Vars { + my $self = shift; + my $pairs; + if ($ENV{CONTENT_LENGTH}) { + $pairs = do { local $/; <STDIN> }; + } else { + $pairs = $ENV{QUERY_STRING}; + } + my $vars = {}; + foreach my $kv (split(/&/, $pairs)) { + my ($k, $v) = split(/=/, $kv, 2); + $vars->{_urldecode($k)} = _urldecode($v); + } + return $vars; +} + +sub _urldecode { + my $v = shift; + $v =~ tr/+/ /; + $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg; + return $v; +} diff --git a/src/net/http/client.go b/src/net/http/client.go new file mode 100644 index 000000000..ce884d1f0 --- /dev/null +++ b/src/net/http/client.go @@ -0,0 +1,511 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client. See RFC 2616. +// +// This is the high-level Client interface. +// The low-level implementation is in transport.go. + +package http + +import ( + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/url" + "strings" + "sync" + "time" +) + +// A Client is an HTTP client. Its zero value (DefaultClient) is a +// usable client that uses DefaultTransport. +// +// The Client's Transport typically has internal state (cached TCP +// connections), so Clients should be reused instead of created as +// needed. Clients are safe for concurrent use by multiple goroutines. +// +// A Client is higher-level than a RoundTripper (such as Transport) +// and additionally handles HTTP details such as cookies and +// redirects. +type Client struct { + // Transport specifies the mechanism by which individual + // HTTP requests are made. + // If nil, DefaultTransport is used. + Transport RoundTripper + + // CheckRedirect specifies the policy for handling redirects. + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via are + // the upcoming request and the requests made already, oldest + // first. If CheckRedirect returns an error, the Client's Get + // method returns both the previous Response and + // CheckRedirect's error (wrapped in a url.Error) instead of + // issuing the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) error + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar CookieJar + + // Timeout specifies a time limit for requests made by this + // Client. The timeout includes connection time, any + // redirects, and reading the response body. The timer remains + // running after Get, Head, Post, or Do return and will + // interrupt reading of the Response.Body. + // + // A Timeout of zero means no timeout. + // + // The Client's Transport must support the CancelRequest + // method or Client will return errors when attempting to make + // a request with Get, Head, Post, or Do. Client's default + // Transport (DefaultTransport) supports CancelRequest. + Timeout time.Duration +} + +// DefaultClient is the default Client and is used by Get, Head, and Post. +var DefaultClient = &Client{} + +// RoundTripper is an interface representing the ability to execute a +// single HTTP transaction, obtaining the Response for a given Request. +// +// A RoundTripper must be safe for concurrent use by multiple +// goroutines. +type RoundTripper interface { + // RoundTrip executes a single HTTP transaction, returning + // the Response for the request req. RoundTrip should not + // attempt to interpret the response. In particular, + // RoundTrip must return err == nil if it obtained a response, + // regardless of the response's HTTP status code. A non-nil + // err should be reserved for failure to obtain a response. + // Similarly, RoundTrip should not attempt to handle + // higher-level protocol details such as redirects, + // authentication, or cookies. + // + // RoundTrip should not modify the request, except for + // consuming and closing the Body, including on errors. The + // request's URL and Header fields are guaranteed to be + // initialized. + RoundTrip(*Request) (*Response, error) +} + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// refererForURL returns a referer without any authentication info or +// an empty string if lastReq scheme is https and newReq scheme is http. +func refererForURL(lastReq, newReq *url.URL) string { + // https://tools.ietf.org/html/rfc7231#section-5.5.2 + // "Clients SHOULD NOT include a Referer header field in a + // (non-secure) HTTP request if the referring page was + // transferred with a secure protocol." + if lastReq.Scheme == "https" && newReq.Scheme == "http" { + return "" + } + referer := lastReq.String() + if lastReq.User != nil { + // This is not very efficient, but is the best we can + // do without: + // - introducing a new method on URL + // - creating a race condition + // - copying the URL struct manually, which would cause + // maintenance problems down the line + auth := lastReq.User.String() + "@" + referer = strings.Replace(referer, auth, "", 1) + } + return referer +} + +// Used in Send to implement io.ReadCloser by bundling together the +// bufio.Reader through which we read the response, and the underlying +// network connection. +type readClose struct { + io.Reader + io.Closer +} + +func (c *Client) send(req *Request) (*Response, error) { + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, err := send(req, c.transport()) + if err != nil { + return nil, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, err +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (e.g. redirects, cookies, auth) as configured on the client. +// +// An error is returned if caused by client policy (such as +// CheckRedirect), or if there was an HTTP protocol error. +// A non-2xx response doesn't cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// +// Callers should close resp.Body when done reading from it. If +// resp.Body is not closed, the Client's underlying RoundTripper +// (typically Transport) may not be able to re-use a persistent TCP +// connection to the server for a subsequent "keep-alive" request. +// +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// +// Generally Get, Post, or PostForm will be used instead of Do. +func (c *Client) Do(req *Request) (resp *Response, err error) { + if req.Method == "GET" || req.Method == "HEAD" { + return c.doFollowingRedirects(req, shouldRedirectGet) + } + if req.Method == "POST" || req.Method == "PUT" { + return c.doFollowingRedirects(req, shouldRedirectPost) + } + return c.send(req) +} + +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. +func send(req *Request, t RoundTripper) (resp *Response, err error) { + if t == nil { + req.closeBody() + return nil, errors.New("http: no Client.Transport or DefaultTransport") + } + + if req.URL == nil { + req.closeBody() + return nil, errors.New("http: nil Request.URL") + } + + if req.RequestURI != "" { + req.closeBody() + return nil, errors.New("http: Request.RequestURI can't be set in client requests.") + } + + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + req.Header = make(Header) + } + + if u := req.URL.User; u != nil { + username := u.Username() + password, _ := u.Password() + req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + } + resp, err = t.RoundTrip(req) + if err != nil { + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + return nil, err + } + return resp, nil +} + +// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +// True if the specified HTTP status code is one for which the Get utility should +// automatically redirect. +func shouldRedirectGet(statusCode int) bool { + switch statusCode { + case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: + return true + } + return false +} + +// True if the specified HTTP status code is one for which the Post utility should +// automatically redirect. +func shouldRedirectPost(statusCode int) bool { + switch statusCode { + case StatusFound, StatusSeeOther: + return true + } + return false +} + +// Get issues a GET to the specified URL. If the response is one of the following +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// An error is returned if there were too many redirects or if there +// was an HTTP protocol error. A non-2xx response doesn't cause an +// error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// Get is a wrapper around DefaultClient.Get. +func Get(url string) (resp *Response, err error) { + return DefaultClient.Get(url) +} + +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// An error is returned if the Client's CheckRedirect function fails +// or if there was an HTTP protocol error. A non-2xx response doesn't +// cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) Get(url string) (resp *Response, err error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.doFollowingRedirects(req, shouldRedirectGet) +} + +func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) { + var base *url.URL + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + + if ireq.URL == nil { + ireq.closeBody() + return nil, errors.New("http: nil Request.URL") + } + + var reqmu sync.Mutex // guards req + req := ireq + + var timer *time.Timer + if c.Timeout > 0 { + type canceler interface { + CancelRequest(*Request) + } + tr, ok := c.transport().(canceler) + if !ok { + return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport()) + } + timer = time.AfterFunc(c.Timeout, func() { + reqmu.Lock() + defer reqmu.Unlock() + tr.CancelRequest(req) + }) + } + + urlStr := "" // next relative or absolute URL to fetch (after first request) + redirectFailed := false + for redirect := 0; ; redirect++ { + if redirect != 0 { + nreq := new(Request) + nreq.Method = ireq.Method + if ireq.Method == "POST" || ireq.Method == "PUT" { + nreq.Method = "GET" + } + nreq.Header = make(Header) + nreq.URL, err = base.Parse(urlStr) + if err != nil { + break + } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if ref := refererForURL(lastReq.URL, nreq.URL); ref != "" { + nreq.Header.Set("Referer", ref) + } + + err = redirectChecker(nreq, via) + if err != nil { + redirectFailed = true + break + } + } + reqmu.Lock() + req = nreq + reqmu.Unlock() + } + + urlStr = req.URL.String() + if resp, err = c.send(req); err != nil { + break + } + + if shouldRedirect(resp.StatusCode) { + // Read the body if small so underlying TCP connection will be re-used. + // No need to check for errors: if it fails, Transport won't reuse it anyway. + const maxBodySlurpSize = 2 << 10 + if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { + io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) + } + resp.Body.Close() + if urlStr = resp.Header.Get("Location"); urlStr == "" { + err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode)) + break + } + base = req.URL + via = append(via, req) + continue + } + if timer != nil { + resp.Body = &cancelTimerBody{timer, resp.Body} + } + return resp, nil + } + + method := ireq.Method + urlErr := &url.Error{ + Op: method[0:1] + strings.ToLower(method[1:]), + URL: urlStr, + Err: err, + } + + if redirectFailed { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See http://golang.org/issue/3795 + return resp, urlErr + } + + if resp != nil { + resp.Body.Close() + } + return nil, urlErr +} + +func defaultCheckRedirect(req *Request, via []*Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil +} + +// Post issues a POST to the specified URL. +// +// Caller should close resp.Body when done reading from it. +// +// Post is a wrapper around DefaultClient.Post +func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, bodyType, body) +} + +// Post issues a POST to the specified URL. +// +// Caller should close resp.Body when done reading from it. +// +// If the provided body is also an io.Closer, it is closed after the +// request. +func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { + req, err := NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return c.doFollowingRedirects(req, shouldRedirectPost) +} + +// PostForm issues a POST to the specified URL, with data's keys and +// values URL-encoded as the request body. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// PostForm is a wrapper around DefaultClient.PostForm +func PostForm(url string, data url.Values) (resp *Response, err error) { + return DefaultClient.PostForm(url, data) +} + +// PostForm issues a POST to the specified URL, +// with data's keys and values urlencoded as the request body. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// Head is a wrapper around DefaultClient.Head +func Head(url string) (resp *Response, err error) { + return DefaultClient.Head(url) +} + +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +func (c *Client) Head(url string) (resp *Response, err error) { + req, err := NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return c.doFollowingRedirects(req, shouldRedirectGet) +} + +type cancelTimerBody struct { + t *time.Timer + rc io.ReadCloser +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == io.EOF { + b.t.Stop() + } + return +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.t.Stop() + return err +} diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go new file mode 100644 index 000000000..56b6563c4 --- /dev/null +++ b/src/net/http/client_test.go @@ -0,0 +1,1075 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for client.go + +package http_test + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + . "net/http" + "net/http/httptest" + "net/url" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Last-Modified", "sometime") + fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") +}) + +// pedanticReadAll works like ioutil.ReadAll but additionally +// verifies that r obeys the documented io.Reader contract. +func pedanticReadAll(r io.Reader) (b []byte, err error) { + var bufa [64]byte + buf := bufa[:] + for { + n, err := r.Read(buf) + if n == 0 && err == nil { + return nil, fmt.Errorf("Read: n=0 with err=nil") + } + b = append(b, buf[:n]...) + if err == io.EOF { + n, err := r.Read(buf) + if n != 0 || err != io.EOF { + return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err) + } + return b, nil + } + if err != nil { + return b, err + } + } +} + +type chanWriter chan string + +func (w chanWriter) Write(p []byte) (n int, err error) { + w <- string(p) + return len(p), nil +} + +func TestClient(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() + + r, err := Get(ts.URL) + var b []byte + if err == nil { + b, err = pedanticReadAll(r.Body) + r.Body.Close() + } + if err != nil { + t.Error(err) + } else if s := string(b); !strings.HasPrefix(s, "User-agent:") { + t.Errorf("Incorrect page body (did not begin with User-agent): %q", s) + } +} + +func TestClientHead(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() + + r, err := Head(ts.URL) + if err != nil { + t.Fatal(err) + } + if _, ok := r.Header["Last-Modified"]; !ok { + t.Error("Last-Modified header not found.") + } +} + +type recordingTransport struct { + req *Request +} + +func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) { + t.req = req + return nil, errors.New("dummy impl") +} + +func TestGetRequestFormat(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + url := "http://dummy.faketld/" + client.Get(url) // Note: doesn't hit network + if tr.req.Method != "GET" { + t.Errorf("expected method %q; got %q", "GET", tr.req.Method) + } + if tr.req.URL.String() != url { + t.Errorf("expected URL %q; got %q", url, tr.req.URL.String()) + } + if tr.req.Header == nil { + t.Errorf("expected non-nil request Header") + } +} + +func TestPostRequestFormat(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + json := `{"key":"value"}` + b := strings.NewReader(json) + client.Post(url, "application/json", b) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if tr.req.Close { + t.Error("got Close true, want false") + } + if g, e := tr.req.ContentLength, int64(len(json)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } +} + +func TestPostFormRequestFormat(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + urlStr := "http://dummy.faketld/" + form := make(url.Values) + form.Set("foo", "bar") + form.Add("foo", "bar2") + form.Set("bar", "baz") + client.PostForm(urlStr, form) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != urlStr { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e { + t.Errorf("got Content-Type %q, want %q", g, e) + } + if tr.req.Close { + t.Error("got Close true, want false") + } + // Depending on map iteration, body can be either of these. + expectedBody := "foo=bar&foo=bar2&bar=baz" + expectedBody1 := "bar=baz&foo=bar&foo=bar2" + if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } + bodyb, err := ioutil.ReadAll(tr.req.Body) + if err != nil { + t.Fatalf("ReadAll on req.Body: %v", err) + } + if g := string(bodyb); g != expectedBody && g != expectedBody1 { + t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1) + } +} + +func TestClientRedirects(t *testing.T) { + defer afterTest(t) + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer(), ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Get, expected error %q, got %q", e, g) + } + + // HEAD request should also have the ability to follow redirects. + _, err = c.Head(ts.URL) + if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Head, expected error %q, got %q", e, g) + } + + // Do should also follow redirects. + greq, _ := NewRequest("GET", ts.URL, nil) + _, err = c.Do(greq) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Do, expected error %q, got %q", e, g) + } + + var checkErr error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) error { + lastVia = via + return checkErr + }} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + res.Body.Close() + finalUrl := res.Request.URL.String() + if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = errors.New("no redirects allowed") + res, err = c.Get(ts.URL) + if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { + t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) + } + if res == nil { + t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") + } + res.Body.Close() + if res.Header.Get("Location") == "" { + t.Errorf("no Location header in Response") + } +} + +func TestPostRedirects(t *testing.T) { + defer afterTest(t) + var log struct { + sync.Mutex + bytes.Buffer + } + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log.Lock() + fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + log.Unlock() + if v := r.URL.Query().Get("code"); v != "" { + code, _ := strconv.Atoi(v) + if code/100 == 3 { + w.Header().Set("Location", ts.URL) + } + w.WriteHeader(code) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int // response code + }{ + {"/", 200}, + {"/?code=301", 301}, + {"/?code=302", 200}, + {"/?code=303", 200}, + {"/?code=404", 404}, + } + for _, tt := range tests { + res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.want { + t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want) + } + } + log.Lock() + got := log.String() + log.Unlock() + want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + if got != want { + t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) + } +} + +var expectedCookies = []*Cookie{ + {Name: "ChocolateChip", Value: "tasty"}, + {Name: "First", Value: "Hit"}, + {Name: "Second", Value: "Hit"}, +} + +var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + for _, cookie := range r.Cookies() { + SetCookie(w, cookie) + } + if r.URL.Path == "/" { + SetCookie(w, expectedCookies[1]) + Redirect(w, r, "/second", StatusMovedPermanently) + } else { + SetCookie(w, expectedCookies[2]) + w.Write([]byte("hello")) + } +}) + +func TestClientSendsCookieFromJar(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + client.Jar = &TestJar{perURL: make(map[string][]*Cookie)} + us := "http://dummy.faketld/" + u, _ := url.Parse(us) + client.Jar.SetCookies(u, expectedCookies) + + client.Get(us) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + client.Head(us) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + client.PostForm(us, url.Values{}) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ := NewRequest("GET", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ = NewRequest("POST", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) +} + +// Just enough correctness for our redirect tests. Uses the URL.Host as the +// scope of all cookies. +type TestJar struct { + m sync.Mutex + perURL map[string][]*Cookie +} + +func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.m.Lock() + defer j.m.Unlock() + if j.perURL == nil { + j.perURL = make(map[string][]*Cookie) + } + j.perURL[u.Host] = cookies +} + +func (j *TestJar) Cookies(u *url.URL) []*Cookie { + j.m.Lock() + defer j.m.Unlock() + return j.perURL[u.Host] +} + +func TestRedirectCookiesJar(t *testing.T) { + defer afterTest(t) + var ts *httptest.Server + ts = httptest.NewServer(echoCookiesRedirectHandler) + defer ts.Close() + c := &Client{ + Jar: new(TestJar), + } + u, _ := url.Parse(ts.URL) + c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + resp.Body.Close() + matchReturnedCookies(t, expectedCookies, resp.Cookies()) +} + +func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { + if len(given) != len(expected) { + t.Logf("Received cookies: %v", given) + t.Errorf("Expected %d cookies, got %d", len(expected), len(given)) + } + for _, ec := range expected { + foundC := false + for _, c := range given { + if ec.Name == c.Name && ec.Value == c.Value { + foundC = true + break + } + } + if !foundC { + t.Errorf("Missing cookie %v", ec) + } + } +} + +func TestJarCalls(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + pathSuffix := r.RequestURI[1:] + if r.RequestURI == "/nosetcookie" { + return // dont set cookies for this path + } + SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix}) + if r.RequestURI == "/" { + Redirect(w, r, "http://secondhost.fake/secondpath", 302) + } + })) + defer ts.Close() + jar := new(RecordingJar) + c := &Client{ + Jar: jar, + Transport: &Transport{ + Dial: func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) + }, + }, + } + _, err := c.Get("http://firsthost.fake/") + if err != nil { + t.Fatal(err) + } + _, err = c.Get("http://firsthost.fake/nosetcookie") + if err != nil { + t.Fatal(err) + } + got := jar.log.String() + want := `Cookies("http://firsthost.fake/") +SetCookie("http://firsthost.fake/", [name=val]) +Cookies("http://secondhost.fake/secondpath") +SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath]) +Cookies("http://firsthost.fake/nosetcookie") +` + if got != want { + t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want) + } +} + +// RecordingJar keeps a log of calls made to it, without +// tracking any cookies. +type RecordingJar struct { + mu sync.Mutex + log bytes.Buffer +} + +func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.logf("SetCookie(%q, %v)\n", u, cookies) +} + +func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { + j.logf("Cookies(%q)\n", u) + return nil +} + +func (j *RecordingJar) logf(format string, args ...interface{}) { + j.mu.Lock() + defer j.mu.Unlock() + fmt.Fprintf(&j.log, format, args...) +} + +func TestStreamingGet(t *testing.T) { + defer afterTest(t) + say := make(chan string) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + for str := range say { + w.Write([]byte(str)) + w.(Flusher).Flush() + } + })) + defer ts.Close() + + c := &Client{} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + var buf [10]byte + for _, str := range []string{"i", "am", "also", "known", "as", "comet"} { + say <- str + n, err := io.ReadFull(res.Body, buf[0:len(str)]) + if err != nil { + t.Fatalf("ReadFull on %q: %v", str, err) + } + if n != len(str) { + t.Fatalf("Receiving %q, only read %d bytes", str, n) + } + got := string(buf[0:n]) + if got != str { + t.Fatalf("Expected %q, got %q", str, got) + } + } + close(say) + _, err = io.ReadFull(res.Body, buf[0:1]) + if err != io.EOF { + t.Fatalf("at end expected EOF, got %v", err) + } +} + +type writeCountingConn struct { + net.Conn + count *int +} + +func (c *writeCountingConn) Write(p []byte) (int, error) { + *c.count++ + return c.Conn.Write(p) +} + +// TestClientWrites verifies that client requests are buffered and we +// don't send a TCP packet per line of the http request + body. +func TestClientWrites(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + writes := 0 + dialer := func(netz string, addr string) (net.Conn, error) { + c, err := net.Dial(netz, addr) + if err == nil { + c = &writeCountingConn{c, &writes} + } + return c, err + } + c := &Client{Transport: &Transport{Dial: dialer}} + + _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Get request did %d Write calls, want 1", writes) + } + + writes = 0 + _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}}) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Post request did %d Write calls, want 1", writes) + } +} + +func TestClientInsecureTransport(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) + defer ts.Close() + + // TODO(bradfitz): add tests for skipping hostname checks too? + // would require a new cert for testing, and probably + // redundant with these tests. + for _, insecure := range []bool{true, false} { + tr := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: insecure, + }, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if (err == nil) != insecure { + t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) + } + if res != nil { + res.Body.Close() + } + } + + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } + +} + +func TestClientErrorWithRequestURI(t *testing.T) { + defer afterTest(t) + req, _ := NewRequest("GET", "http://localhost:1234/", nil) + req.RequestURI = "/this/field/is/illegal/and/should/error/" + _, err := DefaultClient.Do(req) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "RequestURI") { + t.Errorf("wanted error mentioning RequestURI; got error: %v", err) + } +} + +func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { + certs := x509.NewCertPool() + for _, c := range ts.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return &Transport{ + TLSClientConfig: &tls.Config{RootCAs: certs}, + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != "127.0.0.1" { + t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + } + })) + defer ts.Close() + + c := &Client{Transport: newTLSTransport(t, ts)} + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ErrorLog = log.New(errc, "", 0) + + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = "badserver" + c := &Client{Transport: trans} + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } + select { + case v := <-errc: + if !strings.Contains(v, "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } +} + +// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName +// when not empty. +// +// tls.Config.ServerName (non-empty, set to "example.com") takes +// precedence over "some-other-host.tld" which previously incorrectly +// took precedence. We don't actually connect to (or even resolve) +// "some-other-host.tld", though, because of the Transport.Dial hook. +// +// The httptest.Server has a cert with "example.com" as its name. +func TestTransportUsesTLSConfigServerName(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + tr := newTLSTransport(t, ts) + tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://some-other-host.tld/") + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +func TestResponseSetsTLSConnectionState(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + tr := newTLSTransport(t, ts) + tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://example.com/") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.TLS == nil { + t.Fatal("Response didn't set TLS Connection State.") + } + if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want { + t.Errorf("TLS Cipher Suite = %d; want %d", got, want) + } +} + +// Verify Response.ContentLength is populated. http://golang.org/issue/4126 +func TestClientHeadContentLength(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if v := r.FormValue("cl"); v != "" { + w.Header().Set("Content-Length", v) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int64 + }{ + {"/?cl=1234", 1234}, + {"/?cl=0", 0}, + {"", -1}, + } + for _, tt := range tests { + req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != tt.want { + t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) + } + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != 0 { + t.Errorf("Unexpected content: %q", bs) + } + } +} + +func TestEmptyPasswordAuth(t *testing.T) { + defer afterTest(t) + gopher := "gopher" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + expected := gopher + ":" + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } + })) + defer ts.Close() + c := &Client{} + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req.URL.User = url.User(gopher) + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() +} + +func TestBasicAuth(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://My%20User:My%20Pass@dummy.faketld/" + expected := "My User:My Pass" + client.Get(url) + + if tr.req.Method != "GET" { + t.Errorf("got method %q, want %q", tr.req.Method, "GET") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + auth := tr.req.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } +} + +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer afterTest(t) + sawRoot := make(chan bool, 1) + sawSlow := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/" { + sawRoot <- true + Redirect(w, r, "/slow", StatusFound) + return + } + if r.URL.Path == "/slow" { + w.Write([]byte("Hello")) + w.(Flusher).Flush() + sawSlow <- true + time.Sleep(2 * time.Second) + return + } + })) + defer ts.Close() + const timeout = 500 * time.Millisecond + c := &Client{ + Timeout: timeout, + } + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + select { + case <-sawRoot: + // good. + default: + t.Fatal("handler never got / request") + } + + select { + case <-sawSlow: + // good. + default: + t.Fatal("handler never got /slow request") + } + + errc := make(chan error, 1) + go func() { + _, err := ioutil.ReadAll(res.Body) + errc <- err + res.Body.Close() + }() + + const failTime = timeout * 2 + select { + case err := <-errc: + if err == nil { + t.Error("expected error from ReadAll") + } + // Expected error. + case <-time.After(failTime): + t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + } +} + +func TestClientRedirectEatsBody(t *testing.T) { + defer afterTest(t) + saw := make(chan string, 2) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + saw <- r.RemoteAddr + if r.URL.Path == "/" { + Redirect(w, r, "/foo", StatusFound) // which includes a body + } + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + var first string + select { + case first = <-saw: + default: + t.Fatal("server didn't see a request") + } + + var second string + select { + case second = <-saw: + default: + t.Fatal("server didn't see a second request") + } + + if first != second { + t.Fatal("server saw different client ports before & after the redirect") + } +} + +// eofReaderFunc is an io.Reader that runs itself, and then returns io.EOF. +type eofReaderFunc func() + +func (f eofReaderFunc) Read(p []byte) (n int, err error) { + f() + return 0, io.EOF +} + +func TestClientTrailers(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") + w.Header().Add("Trailer", "Server-Trailer-C") + + var decl []string + for k := range r.Trailer { + decl = append(decl, k) + } + sort.Strings(decl) + + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Server reading request body: %v", err) + } + if string(slurp) != "foo" { + t.Errorf("Server read request body %q; want foo", slurp) + } + if r.Trailer == nil { + io.WriteString(w, "nil Trailer") + } else { + fmt.Fprintf(w, "decl: %v, vals: %s, %s", + decl, + r.Trailer.Get("Client-Trailer-A"), + r.Trailer.Get("Client-Trailer-B")) + } + + // TODO: golang.org/issue/7759: there's no way yet for + // the server to set trailers without hijacking, so do + // that for now, just to test the client. Later, in + // Go 1.4, it should be implicit that any mutations + // to w.Header() after the initial write are the + // trailers to be sent, if and only if they were + // previously declared with w.Header().Set("Trailer", + // ..keys..) + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + t := Header{} + t.Set("Server-Trailer-A", "valuea") + t.Set("Server-Trailer-C", "valuec") // skipping B + buf.WriteString("0\r\n") // eof + t.Write(buf) + buf.WriteString("\r\n") // end of trailers + buf.Flush() + conn.Close() + })) + defer ts.Close() + + var req *Request + req, _ = NewRequest("POST", ts.URL, io.MultiReader( + eofReaderFunc(func() { + req.Trailer["Client-Trailer-A"] = []string{"valuea"} + }), + strings.NewReader("foo"), + eofReaderFunc(func() { + req.Trailer["Client-Trailer-B"] = []string{"valueb"} + }), + )) + req.Trailer = Header{ + "Client-Trailer-A": nil, // to be set later + "Client-Trailer-B": nil, // to be set later + } + req.ContentLength = -1 + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { + t.Error(err) + } + want := Header{ + "Server-Trailer-A": []string{"valuea"}, + "Server-Trailer-B": nil, + "Server-Trailer-C": []string{"valuec"}, + } + if !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want) + } +} + +func TestReferer(t *testing.T) { + tests := []struct { + lastReq, newReq string // from -> to URLs + want string + }{ + // don't send user: + {"http://gopher@test.com", "http://link.com", "http://test.com"}, + {"https://gopher@test.com", "https://link.com", "https://test.com"}, + + // don't send a user and password: + {"http://gopher:go@test.com", "http://link.com", "http://test.com"}, + {"https://gopher:go@test.com", "https://link.com", "https://test.com"}, + + // nothing to do: + {"http://test.com", "http://link.com", "http://test.com"}, + {"https://test.com", "https://link.com", "https://test.com"}, + + // https to http doesn't send a referer: + {"https://test.com", "http://link.com", ""}, + {"https://gopher:go@test.com", "http://link.com", ""}, + } + for _, tt := range tests { + l, err := url.Parse(tt.lastReq) + if err != nil { + t.Fatal(err) + } + n, err := url.Parse(tt.newReq) + if err != nil { + t.Fatal(err) + } + r := ExportRefererForURL(l, n) + if r != tt.want { + t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want) + } + } +} diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go new file mode 100644 index 000000000..a0d0fdbbd --- /dev/null +++ b/src/net/http/cookie.go @@ -0,0 +1,363 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "fmt" + "log" + "net" + "strconv" + "strings" + "time" +) + +// This implementation is done according to RFC 6265: +// +// http://tools.ietf.org/html/rfc6265 + +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. +type Cookie struct { + Name string + Value string + Path string + Domain string + Expires time.Time + RawExpires string + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// readSetCookies parses all "Set-Cookie" values from +// the header h and returns the successfully parsed Cookies. +func readSetCookies(h Header) []*Cookie { + cookies := []*Cookie{} + for _, line := range h["Set-Cookie"] { + parts := strings.Split(strings.TrimSpace(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = strings.TrimSpace(parts[0]) + j := strings.Index(parts[0], "=") + if j < 0 { + continue + } + name, value := parts[0][:j], parts[0][j+1:] + if !isCookieNameValid(name) { + continue + } + value, success := parseCookieValue(value, true) + if !success { + continue + } + c := &Cookie{ + Name: name, + Value: value, + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = strings.TrimSpace(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val := parts[i], "" + if j := strings.Index(attr, "="); j >= 0 { + attr, val = attr[:j], attr[j+1:] + } + lowerAttr := strings.ToLower(attr) + val, success = parseCookieValue(val, false) + if !success { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + switch lowerAttr { + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "domain": + c.Domain = val + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs != 0 && val[0] == '0' { + break + } + if secs <= 0 { + c.MaxAge = -1 + } else { + c.MaxAge = secs + } + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } + } + c.Expires = exptime.UTC() + continue + case "path": + c.Path = val + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + return cookies +} + +// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. +func SetCookie(w ResponseWriter, cookie *Cookie) { + w.Header().Add("Set-Cookie", cookie.String()) +} + +// String returns the serialization of the cookie for use in a Cookie +// header (if only Name and Value are set) or a Set-Cookie response +// header (if other fields are set). +func (c *Cookie) String() string { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + if len(c.Path) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path)) + } + if len(c.Domain) > 0 { + if validCookieDomain(c.Domain) { + // A c.Domain containing illegal characters is not + // sanitized but simply dropped which turns the cookie + // into a host-only cookie. A leading dot is okay + // but won't be sent. + d := c.Domain + if d[0] == '.' { + d = d[1:] + } + fmt.Fprintf(&b, "; Domain=%s", d) + } else { + log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", + c.Domain) + } + } + if c.Expires.Unix() > 0 { + fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(time.RFC1123)) + } + if c.MaxAge > 0 { + fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) + } else if c.MaxAge < 0 { + fmt.Fprintf(&b, "; Max-Age=0") + } + if c.HttpOnly { + fmt.Fprintf(&b, "; HttpOnly") + } + if c.Secure { + fmt.Fprintf(&b, "; Secure") + } + return b.String() +} + +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned +func readCookies(h Header, filter string) []*Cookie { + cookies := []*Cookie{} + lines, ok := h["Cookie"] + if !ok { + return cookies + } + + for _, line := range lines { + parts := strings.Split(strings.TrimSpace(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + // Per-line attributes + parsedPairs := 0 + for i := 0; i < len(parts); i++ { + parts[i] = strings.TrimSpace(parts[i]) + if len(parts[i]) == 0 { + continue + } + name, val := parts[i], "" + if j := strings.Index(name, "="); j >= 0 { + name, val = name[:j], name[j+1:] + } + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { + continue + } + val, success := parseCookieValue(val, true) + if !success { + continue + } + cookies = append(cookies, &Cookie{Name: name, Value: val}) + parsedPairs++ + } + } + return cookies +} + +// validCookieDomain returns wheter v is a valid cookie domain-value. +func validCookieDomain(v string) bool { + if isCookieDomainName(v) { + return true + } + if net.ParseIP(v) != nil && !strings.Contains(v, ":") { + return true + } + return false +} + +// isCookieDomainName returns whether s is a valid domain name or a valid +// domain name with a leading dot '.'. It is almost a direct copy of +// package net's isDomainName. +func isCookieDomainName(s string) bool { + if len(s) == 0 { + return false + } + if len(s) > 255 { + return false + } + + if s[0] == '.' { + // A cookie a domain attribute may start with a leading dot. + s = s[1:] + } + last := byte('.') + ok := false // Ok once we've seen a letter. + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // No '_' allowed here (in contrast to package net). + ok = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + + return ok +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeCookieName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +// http://tools.ietf.org/html/rfc6265#section-4.1.1 +// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE ) +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value in when value starts or ends +// with a comma or space. +// See http://golang.org/issue/7243 for the discussion. +func sanitizeCookieValue(v string) string { + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' { + return `"` + v + `"` + } + return v +} + +func validCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +// path-av = "Path=" path-value +// path-value = <any CHAR except CTLs or ";"> +func sanitizeCookiePath(v string) string { + return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v) +} + +func validCookiePathByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != ';' +} + +func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + buf := make([]byte, 0, len(v)) + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + buf = append(buf, b) + } + } + return string(buf) +} + +func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { + // Strip the quotes, if present. + if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } + for i := 0; i < len(raw); i++ { + if !validCookieValueByte(raw[i]) { + return "", false + } + } + return raw, true +} + +func isCookieNameValid(raw string) bool { + return strings.IndexFunc(raw, isNotToken) < 0 +} diff --git a/src/net/http/cookie_test.go b/src/net/http/cookie_test.go new file mode 100644 index 000000000..98dc2fade --- /dev/null +++ b/src/net/http/cookie_test.go @@ -0,0 +1,412 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "os" + "reflect" + "strings" + "testing" + "time" +) + +var writeSetCookiesTests = []struct { + Cookie *Cookie + Raw string +}{ + { + &Cookie{Name: "cookie-1", Value: "v$1"}, + "cookie-1=v$1", + }, + { + &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + "cookie-2=two; Max-Age=3600", + }, + { + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + "cookie-3=three; Domain=example.com", + }, + { + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, + "cookie-4=four; Path=/restricted/", + }, + { + &Cookie{Name: "cookie-5", Value: "five", Domain: "wrong;bad.abc"}, + "cookie-5=five", + }, + { + &Cookie{Name: "cookie-6", Value: "six", Domain: "bad-.abc"}, + "cookie-6=six", + }, + { + &Cookie{Name: "cookie-7", Value: "seven", Domain: "127.0.0.1"}, + "cookie-7=seven; Domain=127.0.0.1", + }, + { + &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"}, + "cookie-8=eight", + }, + // The "special" cookies have values containing commas or spaces which + // are disallowed by RFC 6265 but are common in the wild. + { + &Cookie{Name: "special-1", Value: "a z"}, + `special-1=a z`, + }, + { + &Cookie{Name: "special-2", Value: " z"}, + `special-2=" z"`, + }, + { + &Cookie{Name: "special-3", Value: "a "}, + `special-3="a "`, + }, + { + &Cookie{Name: "special-4", Value: " "}, + `special-4=" "`, + }, + { + &Cookie{Name: "special-5", Value: "a,z"}, + `special-5=a,z`, + }, + { + &Cookie{Name: "special-6", Value: ",z"}, + `special-6=",z"`, + }, + { + &Cookie{Name: "special-7", Value: "a,"}, + `special-7="a,"`, + }, + { + &Cookie{Name: "special-8", Value: ","}, + `special-8=","`, + }, + { + &Cookie{Name: "empty-value", Value: ""}, + `empty-value=`, + }, +} + +func TestWriteSetCookies(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + + for i, tt := range writeSetCookiesTests { + if g, e := tt.Cookie.String(), tt.Raw; g != e { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g) + continue + } + } + + if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } +} + +type headerOnlyResponseWriter Header + +func (ho headerOnlyResponseWriter) Header() Header { + return Header(ho) +} + +func (ho headerOnlyResponseWriter) Write([]byte) (int, error) { + panic("NOIMPL") +} + +func (ho headerOnlyResponseWriter) WriteHeader(int) { + panic("NOIMPL") +} + +func TestSetCookie(t *testing.T) { + m := make(Header) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"}) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}) + if l := len(m["Set-Cookie"]); l != 2 { + t.Fatalf("expected %d cookies, got %d", 2, l) + } + if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e { + t.Errorf("cookie #1: want %q, got %q", e, g) + } + if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e { + t.Errorf("cookie #2: want %q, got %q", e, g) + } +} + +var addCookieTests = []struct { + Cookies []*Cookie + Raw string +}{ + { + []*Cookie{}, + "", + }, + { + []*Cookie{{Name: "cookie-1", Value: "v$1"}}, + "cookie-1=v$1", + }, + { + []*Cookie{ + {Name: "cookie-1", Value: "v$1"}, + {Name: "cookie-2", Value: "v$2"}, + {Name: "cookie-3", Value: "v$3"}, + }, + "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3", + }, +} + +func TestAddCookie(t *testing.T) { + for i, tt := range addCookieTests { + req, _ := NewRequest("GET", "http://example.com/", nil) + for _, c := range tt.Cookies { + req.AddCookie(c) + } + if g := req.Header.Get("Cookie"); g != tt.Raw { + t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g) + continue + } + } +} + +var readSetCookiesTests = []struct { + Header Header + Cookies []*Cookie +}{ + { + Header{"Set-Cookie": {"Cookie-1=v$1"}}, + []*Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}}, + }, + { + Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}}, + []*Cookie{{ + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC), + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }}, + }, + { + Header{"Set-Cookie": {".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}}, + []*Cookie{{ + Name: ".ASPXAUTH", + Value: "7E3AA", + Path: "/", + Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC), + RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT", + HttpOnly: true, + Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }}, + }, + { + Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly"}}, + []*Cookie{{ + Name: "ASP.NET_SessionId", + Value: "foo", + Path: "/", + HttpOnly: true, + Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly", + }}, + }, + // Make sure we can properly read back the Set-Cookie headers we create + // for values containing spaces or commas: + { + Header{"Set-Cookie": {`special-1=a z`}}, + []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}}, + }, + { + Header{"Set-Cookie": {`special-2=" z"`}}, + []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}}, + }, + { + Header{"Set-Cookie": {`special-3="a "`}}, + []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}}, + }, + { + Header{"Set-Cookie": {`special-4=" "`}}, + []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}}, + }, + { + Header{"Set-Cookie": {`special-5=a,z`}}, + []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}}, + }, + { + Header{"Set-Cookie": {`special-6=",z"`}}, + []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}}, + }, + { + Header{"Set-Cookie": {`special-7=a,`}}, + []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}}, + }, + { + Header{"Set-Cookie": {`special-8=","`}}, + []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}}, + }, + + // TODO(bradfitz): users have reported seeing this in the + // wild, but do browsers handle it? RFC 6265 just says "don't + // do that" (section 3) and then never mentions header folding + // again. + // Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}}, +} + +func toJSON(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%#v", v) + } + return string(b) +} + +func TestReadSetCookies(t *testing.T) { + for i, tt := range readSetCookiesTests { + for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input + c := readSetCookies(tt.Header) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } + } + } +} + +var readCookiesTests = []struct { + Header Header + Filter string + Cookies []*Cookie +}{ + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "c2", + []*Cookie{ + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "c2", + []*Cookie{ + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {`Cookie-1="v$1"; c2="v2"`}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, +} + +func TestReadCookies(t *testing.T) { + for i, tt := range readCookiesTests { + for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input + c := readCookies(tt.Header, tt.Filter) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } + } + } +} + +func TestSetCookieDoubleQuotes(t *testing.T) { + res := &Response{Header: Header{}} + res.Header.Add("Set-Cookie", `quoted0=none; max-age=30`) + res.Header.Add("Set-Cookie", `quoted1="cookieValue"; max-age=31`) + res.Header.Add("Set-Cookie", `quoted2=cookieAV; max-age="32"`) + res.Header.Add("Set-Cookie", `quoted3="both"; max-age="33"`) + got := res.Cookies() + want := []*Cookie{ + {Name: "quoted0", Value: "none", MaxAge: 30}, + {Name: "quoted1", Value: "cookieValue", MaxAge: 31}, + {Name: "quoted2", Value: "cookieAV"}, + {Name: "quoted3", Value: "both"}, + } + if len(got) != len(want) { + t.Fatal("got %d cookies, want %d", len(got), len(want)) + } + for i, w := range want { + g := got[i] + if g.Name != w.Name || g.Value != w.Value || g.MaxAge != w.MaxAge { + t.Errorf("cookie #%d:\ngot %v\nwant %v", i, g, w) + } + } +} + +func TestCookieSanitizeValue(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + + tests := []struct { + in, want string + }{ + {"foo", "foo"}, + {"foo;bar", "foobar"}, + {"foo\\bar", "foobar"}, + {"foo\"bar", "foobar"}, + {"\x00\x7e\x7f\x80", "\x7e"}, + {`"withquotes"`, "withquotes"}, + {"a z", "a z"}, + {" z", `" z"`}, + {"a ", `"a "`}, + } + for _, tt := range tests { + if got := sanitizeCookieValue(tt.in); got != tt.want { + t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want) + } + } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } +} + +func TestCookieSanitizePath(t *testing.T) { + defer log.SetOutput(os.Stderr) + var logbuf bytes.Buffer + log.SetOutput(&logbuf) + + tests := []struct { + in, want string + }{ + {"/path", "/path"}, + {"/path with space/", "/path with space/"}, + {"/just;no;semicolon\x00orstuff/", "/justnosemicolonorstuff/"}, + } + for _, tt := range tests { + if got := sanitizeCookiePath(tt.in); got != tt.want { + t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want) + } + } + + if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) { + t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) + } +} diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go new file mode 100644 index 000000000..0e0fac928 --- /dev/null +++ b/src/net/http/cookiejar/jar.go @@ -0,0 +1,497 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar. +package cookiejar + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" +) + +// PublicSuffixList provides the public suffix of a domain. For example: +// - the public suffix of "example.com" is "com", +// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and +// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us". +// +// Implementations of PublicSuffixList must be safe for concurrent use by +// multiple goroutines. +// +// An implementation that always returns "" is valid and may be useful for +// testing but it is not secure: it means that the HTTP server for foo.com can +// set a cookie for bar.com. +// +// A public suffix list implementation is in the package +// golang.org/x/net/publicsuffix. +type PublicSuffixList interface { + // PublicSuffix returns the public suffix of domain. + // + // TODO: specify which of the caller and callee is responsible for IP + // addresses, for leading and trailing dots, for case sensitivity, and + // for IDN/Punycode. + PublicSuffix(domain string) string + + // String returns a description of the source of this public suffix + // list. The description will typically contain something like a time + // stamp or version number. + String() string +} + +// Options are the options for creating a new Jar. +type Options struct { + // PublicSuffixList is the public suffix list that determines whether + // an HTTP server can set a cookie for a domain. + // + // A nil value is valid and may be useful for testing but it is not + // secure: it means that the HTTP server for foo.co.uk can set a cookie + // for bar.co.uk. + PublicSuffixList PublicSuffixList +} + +// Jar implements the http.CookieJar interface from the net/http package. +type Jar struct { + psList PublicSuffixList + + // mu locks the remaining fields. + mu sync.Mutex + + // entries is a set of entries, keyed by their eTLD+1 and subkeyed by + // their name/domain/path. + entries map[string]map[string]entry + + // nextSeqNum is the next sequence number assigned to a new cookie + // created SetCookies. + nextSeqNum uint64 +} + +// New returns a new cookie jar. A nil *Options is equivalent to a zero +// Options. +func New(o *Options) (*Jar, error) { + jar := &Jar{ + entries: make(map[string]map[string]entry), + } + if o != nil { + jar.psList = o.PublicSuffixList + } + return jar, nil +} + +// entry is the internal representation of a cookie. +// +// This struct type is not used outside of this package per se, but the exported +// fields are those of RFC 6265. +type entry struct { + Name string + Value string + Domain string + Path string + Secure bool + HttpOnly bool + Persistent bool + HostOnly bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + // seqNum is a sequence number so that Cookies returns cookies in a + // deterministic order, even for cookies that have equal Path length and + // equal Creation time. This simplifies testing. + seqNum uint64 +} + +// Id returns the domain;path;name triple of e as an id. +func (e *entry) id() string { + return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) +} + +// shouldSend determines whether e's cookie qualifies to be included in a +// request to host/path. It is the caller's responsibility to check if the +// cookie is expired. +func (e *entry) shouldSend(https bool, host, path string) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +// domainMatch implements "domain-match" of RFC 6265 section 5.1.3. +func (e *entry) domainMatch(host string) bool { + if e.Domain == host { + return true + } + return !e.HostOnly && hasDotSuffix(host, e.Domain) +} + +// pathMatch implements "path-match" according to RFC 6265 section 5.1.4. +func (e *entry) pathMatch(requestPath string) bool { + if requestPath == e.Path { + return true + } + if strings.HasPrefix(requestPath, e.Path) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if requestPath[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +// hasDotSuffix reports whether s ends in "."+suffix. +func hasDotSuffix(s, suffix string) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix +} + +// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 +// section 5.4 point 2: by longest path and then by earliest creation time. +type byPathLength []entry + +func (s byPathLength) Len() int { return len(s) } + +func (s byPathLength) Less(i, j int) bool { + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum +} + +func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Cookies implements the Cookies method of the http.CookieJar interface. +// +// It returns an empty slice if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { + return j.cookies(u, time.Now()) +} + +// cookies is like Cookies but takes the current time as a parameter. +func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { + if u.Scheme != "http" && u.Scheme != "https" { + return cookies + } + host, err := canonicalHost(u.Host) + if err != nil { + return cookies + } + key := jarKey(host, j.psList) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + if submap == nil { + return cookies + } + + https := u.Scheme == "https" + path := u.Path + if path == "" { + path = "/" + } + + modified := false + var selected []entry + for id, e := range submap { + if e.Persistent && !e.Expires.After(now) { + delete(submap, id) + modified = true + continue + } + if !e.shouldSend(https, host, path) { + continue + } + e.LastAccess = now + submap[id] = e + selected = append(selected, e) + modified = true + } + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } + + sort.Sort(byPathLength(selected)) + for _, e := range selected { + cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) + } + + return cookies +} + +// SetCookies implements the SetCookies method of the http.CookieJar interface. +// +// It does nothing if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +// setCookies is like SetCookies but takes the current time as parameter. +func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + host, err := canonicalHost(u.Host) + if err != nil { + return + } + key := jarKey(host, j.psList) + defPath := defaultPath(u.Path) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove, err := j.newEntry(cookie, now, defPath, host) + if err != nil { + continue + } + id := e.id() + if remove { + if submap != nil { + if _, ok := submap[id]; ok { + delete(submap, id) + modified = true + } + } + continue + } + if submap == nil { + submap = make(map[string]entry) + } + + if old, ok := submap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + e.LastAccess = now + submap[id] = e + modified = true + } + + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } +} + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func canonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return toASCII(host) +} + +// hasPort reports whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} + +// jarKey returns the key to use for a jar. +func jarKey(host string, psl PublicSuffixList) string { + if isIP(host) { + return host + } + + var i int + if psl == nil { + i = strings.LastIndex(host, ".") + if i == -1 { + return host + } + } else { + suffix := psl.PublicSuffix(host) + if suffix == host { + return host + } + i = len(host) - len(suffix) + if i <= 0 || host[i-1] != '.' { + // The provided public suffix list psl is broken. + // Storing cookies under host is a safe stopgap. + return host + } + } + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +// isIP reports whether host is an IP address. +func isIP(host string) bool { + return net.ParseIP(host) != nil +} + +// defaultPath returns the directory part of an URL's path according to +// RFC 6265 section 5.1.4. +func defaultPath(path string) string { + if len(path) == 0 || path[0] != '/' { + return "/" // Path is empty or malformed. + } + + i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1. + if i == 0 { + return "/" // Path has the form "/abc". + } + return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". +} + +// newEntry creates an entry from a http.Cookie c. now is the current time and +// is compared to c.Expires to determine deletion of c. defPath and host are the +// default-path and the canonical host name of the URL c was received from. +// +// remove records whether the jar should delete this cookie, as it has already +// expired with respect to now. In this case, e may be incomplete, but it will +// be valid to call e.id (which depends on e's Name, Domain and Path). +// +// A malformed c.Domain will result in an error. +func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) { + e.Name = c.Name + + if c.Path == "" || c.Path[0] != '/' { + e.Path = defPath + } else { + e.Path = c.Path + } + + e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain) + if err != nil { + return e, false, err + } + + // MaxAge takes precedence over Expires. + if c.MaxAge < 0 { + return e, true, nil + } else if c.MaxAge > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) + e.Persistent = true + } else { + if c.Expires.IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expires.After(now) { + return e, true, nil + } + e.Expires = c.Expires + e.Persistent = true + } + } + + e.Value = c.Value + e.Secure = c.Secure + e.HttpOnly = c.HttpOnly + + return e, false, nil +} + +var ( + errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") + errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") + errNoHostname = errors.New("cookiejar: no host name available (IP only)") +) + +// endOfTime is the time when session (non-persistent) cookies expire. +// This instant is representable in most date/time formats (not just +// Go's time.Time) and should be far enough in the future. +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +// domainAndType determines the cookie's domain and hostOnly attribute. +func (j *Jar) domainAndType(host, domain string) (string, bool, error) { + if domain == "" { + // No domain attribute in the SetCookie header indicates a + // host cookie. + return host, true, nil + } + + if isIP(host) { + // According to RFC 6265 domain-matching includes not being + // an IP address. + // TODO: This might be relaxed as in common browsers. + return "", false, errNoHostname + } + + // From here on: If the cookie is valid, it is a domain cookie (with + // the one exception of a public suffix below). + // See RFC 6265 section 5.2.3. + if domain[0] == '.' { + domain = domain[1:] + } + + if len(domain) == 0 || domain[0] == '.' { + // Received either "Domain=." or "Domain=..some.thing", + // both are illegal. + return "", false, errMalformedDomain + } + domain = strings.ToLower(domain) + + if domain[len(domain)-1] == '.' { + // We received stuff like "Domain=www.example.com.". + // Browsers do handle such stuff (actually differently) but + // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in + // requiring a reject. 4.1.2.3 is not normative, but + // "Domain Matching" (5.1.3) and "Canonicalized Host Names" + // (5.1.2) are. + return "", false, errMalformedDomain + } + + // See RFC 6265 section 5.3 #5. + if j.psList != nil { + if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) { + if host == domain { + // This is the one exception in which a cookie + // with a domain attribute is a host cookie. + return host, true, nil + } + return "", false, errIllegalDomain + } + } + + // The domain must domain-match host: www.mycompany.com cannot + // set cookies for .ourcompetitors.com. + if host != domain && !hasDotSuffix(host, domain) { + return "", false, errIllegalDomain + } + + return domain, false, nil +} diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go new file mode 100644 index 000000000..3aa601586 --- /dev/null +++ b/src/net/http/cookiejar/jar_test.go @@ -0,0 +1,1267 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "testing" + "time" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +// testPSL implements PublicSuffixList with just two rules: "co.uk" +// and the default rule "*". +type testPSL struct{} + +func (testPSL) String() string { + return "testPSL" +} +func (testPSL) PublicSuffix(d string) string { + if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { + return "co.uk" + } + return d[strings.LastIndex(d, ".")+1:] +} + +// newTestJar creates an empty Jar with testPSL as the public suffix list. +func newTestJar() *Jar { + jar, err := New(&Options{PublicSuffixList: testPSL{}}) + if err != nil { + panic(err) + } + return jar +} + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix(tc.s, tc.suffix) + want := strings.HasSuffix(tc.s, "."+tc.suffix) + if got != want { + t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want) + } + } +} + +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bรผcher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + "[bad.unmatched.bracket:": "error", +} + +func TestCanonicalHost(t *testing.T) { + for h, want := range canonicalHostTests { + got, err := canonicalHost(h) + if want == "error" { + if err == nil { + t.Errorf("%q: got nil error, want non-nil", h) + } + continue + } + if err != nil { + t.Errorf("%q: %v", h, err) + continue + } + if got != want { + t.Errorf("%q: got %q, want %q", h, got, want) + continue + } + } +} + +var hasPortTests = map[string]bool{ + "www.example.com": false, + "www.example.com:80": true, + "127.0.0.1": false, + "127.0.0.1:8080": true, + "2001:4860:0:2001::68": false, + "[2001::0:::68]:80": true, +} + +func TestHasPort(t *testing.T) { + for host, want := range hasPortTests { + if got := hasPort(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "bbc.co.uk", + "www.bbc.co.uk": "bbc.co.uk", + "bbc.co.uk": "bbc.co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + if got := jarKey(host, testPSL{}); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var jarKeyNilPSLTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKeyNilPSL(t *testing.T) { + for host, want := range jarKeyNilPSLTests { + if got := jarKey(host, nil); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var isIPTests = map[string]bool{ + "127.0.0.1": true, + "1.2.3.4": true, + "2001:4860:0:2001::68": true, + "example.com": false, + "1.1.1.300": false, + "www.foo.bar.net": false, + "123.foo.bar.net": false, +} + +func TestIsIP(t *testing.T) { + for host, want := range isIPTests { + if got := isIP(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var defaultPathTests = map[string]string{ + "/": "/", + "/abc": "/", + "/abc/": "/abc", + "/abc/xyz": "/abc", + "/abc/xyz/": "/abc/xyz", + "/a/b/c.html": "/a/b", + "": "/", + "strange": "/", + "//": "/", + "/a//b": "/a/", + "/a/./b": "/a/.", + "/a/../b": "/a/..", +} + +func TestDefaultPath(t *testing.T) { + for path, want := range defaultPathTests { + if got := defaultPath(path); got != want { + t.Errorf("%q: got %q, want %q", path, got, want) + } + } +} + +var domainAndTypeTests = [...]struct { + host string // host Set-Cookie header was received from + domain string // domain attribute in Set-Cookie header + wantDomain string // expected domain of cookie + wantHostOnly bool // expected host-cookie flag + wantErr error // expected error +}{ + {"www.example.com", "", "www.example.com", true, nil}, + {"127.0.0.1", "", "127.0.0.1", true, nil}, + {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil}, + {"www.example.com", "example.com", "example.com", false, nil}, + {"www.example.com", ".example.com", "example.com", false, nil}, + {"www.example.com", "www.example.com", "www.example.com", false, nil}, + {"www.example.com", ".www.example.com", "www.example.com", false, nil}, + {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil}, + {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil}, + {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil}, + {"127.0.0.1", "127.0.0.1", "", false, errNoHostname}, + {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname}, + {"www.example.com", ".", "", false, errMalformedDomain}, + {"www.example.com", "..", "", false, errMalformedDomain}, + {"www.example.com", "other.com", "", false, errIllegalDomain}, + {"www.example.com", "com", "", false, errIllegalDomain}, + {"www.example.com", ".com", "", false, errIllegalDomain}, + {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain}, + {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain}, + {"com", "", "com", true, nil}, + {"com", "com", "com", true, nil}, + {"com", ".com", "com", true, nil}, + {"co.uk", "", "co.uk", true, nil}, + {"co.uk", "co.uk", "co.uk", true, nil}, + {"co.uk", ".co.uk", "co.uk", true, nil}, +} + +func TestDomainAndType(t *testing.T) { + jar := newTestJar() + for _, tc := range domainAndTypeTests { + domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) + if err != tc.wantErr { + t.Errorf("%q/%q: got %q error, want %q", + tc.host, tc.domain, err, tc.wantErr) + continue + } + if err != nil { + continue + } + if domain != tc.wantDomain || hostOnly != tc.wantHostOnly { + t.Errorf("%q/%q: got %q/%t want %q/%t", + tc.host, tc.domain, domain, hostOnly, + tc.wantDomain, tc.wantHostOnly) + } + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil || u.Scheme == "" || u.Host == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *Jar) { + now := tNow + + // Populate jar with cookies. + setCookies := make([]*http.Cookie, len(test.setCookies)) + for i, cs := range test.setCookies { + cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + if len(cookies) != 1 { + panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + } + setCookies[i] = cookies[0] + } + jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + now = now.Add(1001 * time.Millisecond) + + // Serialize non-expired entries in the form "name1=val1 name2=val2". + var cs []string + for _, submap := range jar.entries { + for _, cookie := range submap { + if !cookie.Expires.After(now) { + continue + } + cs = append(cs, cookie.Name+"="+cookie.Value) + } + } + sort.Strings(cs) + got := strings.Join(cs, " ") + + // Make sure jar content matches our expectations. + if got != test.content { + t.Errorf("Test %q Content\ngot %q\nwant %q", + test.description, got, test.content) + } + + // Test different calls to Cookies. + for i, query := range test.queries { + now = now.Add(1001 * time.Millisecond) + var s []string + for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + s = append(s, c.Name+"="+c.Value) + } + if got := strings.Join(s, " "); got != query.want { + t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + } + } +} + +// basicsTests contains fundamental tests. Each jarTest has to be performed on +// a fresh, empty Jar. +var basicsTests = [...]jarTest{ + { + "Retrieval of a plain host cookie.", + "http://www.host.test/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + {"ftp://www.host.test", ""}, + {"ftp://www.host.test/", ""}, + {"ftp://www.host.test/some/path", ""}, + {"http://www.other.org", ""}, + {"http://sibling.host.test", ""}, + {"http://deep.www.host.test", ""}, + }, + }, + { + "Secure cookies are not returned to http.", + "http://www.host.test/", + []string{"A=a; secure"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some/path", ""}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + }, + }, + { + "Explicit path.", + "http://www.host.test/", + []string{"A=a; path=/some/path"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #1: path is a directory.", + "http://www.host.test/some/path/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #2: path is not a directory.", + "http://www.host.test/some/path/index.html", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #3: no path in URL at all.", + "http://www.host.test", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + }, + }, + { + "Cookies are sorted by path length.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "D=d; path=/foo"}, + "A=a B=b C=c D=d", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"}, + {"http://www.host.test/foo/bar", "A=a D=d"}, + }, + }, + { + "Creation time determines sorting on same length paths.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "X=x; path=/foo/bar", + "Y=y; path=/foo/bar/baz/qux", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "W=w; path=/foo/bar/baz", + "Z=z; path=/foo", + "D=d; path=/foo"}, + "A=a B=b C=c D=d W=w X=x Y=y Z=z", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"}, + }, + }, + { + "Sorting of same-name cookies.", + "http://www.host.test/", + []string{ + "A=1; path=/", + "A=2; path=/path", + "A=3; path=/quux", + "A=4; path=/path/foo", + "A=5; domain=.host.test; path=/path", + "A=6; domain=.host.test; path=/quux", + "A=7; domain=.host.test; path=/path/foo", + }, + "A=1 A=2 A=3 A=4 A=5 A=6 A=7", + []query{ + {"http://www.host.test/path", "A=2 A=5 A=1"}, + {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"}, + }, + }, + { + "Disallow domain cookie on public suffix.", + "http://www.bbc.co.uk", + []string{ + "a=1", + "b=2; domain=co.uk", + }, + "a=1", + []query{{"http://www.bbc.co.uk", "a=1"}}, + }, + { + "Host cookie on IP.", + "http://192.168.0.10", + []string{"a=1"}, + "a=1", + []query{{"http://192.168.0.10", "a=1"}}, + }, + { + "Port is ignored #1.", + "http://www.host.test/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + }, + }, + { + "Port is ignored #2.", + "http://www.host.test:8080/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + {"http://www.host.test:1234/", "a=1"}, + }, + }, +} + +func TestBasics(t *testing.T) { + for _, test := range basicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// updateAndDeleteTests contains jarTests which must be performed on the same +// Jar. +var updateAndDeleteTests = [...]jarTest{ + { + "Set initial cookies.", + "http://www.host.test", + []string{ + "a=1", + "b=2; secure", + "c=3; httponly", + "d=4; secure; httponly"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://www.host.test", "a=1 c=3"}, + {"https://www.host.test", "a=1 b=2 c=3 d=4"}, + }, + }, + { + "Update value via http.", + "http://www.host.test", + []string{ + "a=w", + "b=x; secure", + "c=y; httponly", + "d=z; secure; httponly"}, + "a=w b=x c=y d=z", + []query{ + {"http://www.host.test", "a=w c=y"}, + {"https://www.host.test", "a=w b=x c=y d=z"}, + }, + }, + { + "Clear Secure flag from a http.", + "http://www.host.test/", + []string{ + "b=xx", + "d=zz; httponly"}, + "a=w b=xx c=y d=zz", + []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}}, + }, + { + "Delete all.", + "http://www.host.test/", + []string{ + "a=1; max-Age=-1", // delete via MaxAge + "b=2; " + expiresIn(-10), // delete via Expires + "c=2; max-age=-1; " + expiresIn(-10), // delete via both + "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence + "", + []query{{"http://www.host.test", ""}}, + }, + { + "Refill #1.", + "http://www.host.test", + []string{ + "A=1", + "A=2; path=/foo", + "A=3; domain=.host.test", + "A=4; path=/foo; domain=.host.test"}, + "A=1 A=2 A=3 A=4", + []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}}, + }, + { + "Refill #2.", + "http://www.google.com", + []string{ + "A=6", + "A=7; path=/foo", + "A=8; domain=.google.com", + "A=9; path=/foo; domain=.google.com"}, + "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"}, + }, + }, + { + "Delete A7.", + "http://www.google.com", + []string{"A=; path=/foo; max-age=-1"}, + "A=1 A=2 A=3 A=4 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A4.", + "http://www.host.test", + []string{"A=; path=/foo; domain=host.test; max-age=-1"}, + "A=1 A=2 A=3 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A6.", + "http://www.google.com", + []string{"A=; max-age=-1"}, + "A=1 A=2 A=3 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A3.", + "http://www.host.test", + []string{"A=; domain=host.test; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "No cross-domain delete.", + "http://www.host.test", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A8 and A9.", + "http://www.google.com", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", ""}, + }, + }, +} + +func TestUpdateAndDelete(t *testing.T) { + jar := newTestJar() + for _, test := range updateAndDeleteTests { + test.run(t, jar) + } +} + +func TestExpiration(t *testing.T) { + jar := newTestJar() + jarTest{ + "Expiration.", + "http://www.host.test", + []string{ + "a=1", + "b=2; max-age=3", + "c=3; " + expiresIn(3), + "d=4; max-age=5", + "e=5; " + expiresIn(5), + "f=6; max-age=100", + }, + "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms + []query{ + {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms + }, + }.run(t, jar) +} + +// +// Tests derived from Chromium's cookie_store_unittest.h. +// + +// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain +// Some of the original tests are in a bad condition (e.g. +// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g. +// TestNonDottedAndTLD #1 and #6) and have not been ported. + +// chromiumBasicsTests contains fundamental tests. Each jarTest has to be +// performed on a fresh, empty Jar. +var chromiumBasicsTests = [...]jarTest{ + { + "DomainWithTrailingDotTest.", + "http://www.google.com/", + []string{ + "a=1; domain=.www.google.com.", + "b=2; domain=.www.google.com.."}, + "", + []query{ + {"http://www.google.com", ""}, + }, + }, + { + "ValidSubdomainTest #1.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"}, + {"http://b.c.d.com", "b=2 c=3 d=4"}, + {"http://c.d.com", "c=3 d=4"}, + {"http://d.com", "d=4"}, + }, + }, + { + "ValidSubdomainTest #2.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com", + "X=bcd; domain=.b.c.d.com", + "X=cd; domain=.c.d.com"}, + "X=bcd X=cd a=1 b=2 c=3 d=4", + []query{ + {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"}, + {"http://c.d.com", "c=3 d=4 X=cd"}, + }, + }, + { + "InvalidDomainTest #1.", + "http://foo.bar.com", + []string{ + "a=1; domain=.yo.foo.bar.com", + "b=2; domain=.foo.com", + "c=3; domain=.bar.foo.com", + "d=4; domain=.foo.bar.com.net", + "e=5; domain=ar.com", + "f=6; domain=.", + "g=7; domain=/", + "h=8; domain=http://foo.bar.com", + "i=9; domain=..foo.bar.com", + "j=10; domain=..bar.com", + "k=11; domain=.foo.bar.com?blah", + "l=12; domain=.foo.bar.com/blah", + "m=12; domain=.foo.bar.com:80", + "n=14; domain=.foo.bar.com:", + "o=15; domain=.foo.bar.com#sup", + }, + "", // Jar is empty. + []query{{"http://foo.bar.com", ""}}, + }, + { + "InvalidDomainTest #2.", + "http://foo.com.com", + []string{"a=1; domain=.foo.com.com.com"}, + "", + []query{{"http://foo.bar.com", ""}}, + }, + { + "DomainWithoutLeadingDotTest #1.", + "http://manage.hosted.filefront.com", + []string{"a=1; domain=filefront.com"}, + "a=1", + []query{{"http://www.filefront.com", "a=1"}}, + }, + { + "DomainWithoutLeadingDotTest #2.", + "http://www.google.com", + []string{"a=1; domain=www.google.com"}, + "a=1", + []query{ + {"http://www.google.com", "a=1"}, + {"http://sub.www.google.com", "a=1"}, + {"http://something-else.com", ""}, + }, + }, + { + "CaseInsensitiveDomainTest.", + "http://www.google.com", + []string{ + "a=1; domain=.GOOGLE.COM", + "b=2; domain=.www.gOOgLE.coM"}, + "a=1 b=2", + []query{{"http://www.google.com", "a=1 b=2"}}, + }, + { + "TestIpAddress #1.", + "http://1.2.3.4/foo", + []string{"a=1; path=/"}, + "a=1", + []query{{"http://1.2.3.4/foo", "a=1"}}, + }, + { + "TestIpAddress #2.", + "http://1.2.3.4/foo", + []string{ + "a=1; domain=.1.2.3.4", + "b=2; domain=.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestIpAddress #3.", + "http://1.2.3.4/foo", + []string{"a=1; domain=1.2.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestNonDottedAndTLD #2.", + "http://com./index.html", + []string{"a=1"}, + "a=1", + []query{ + {"http://com./index.html", "a=1"}, + {"http://no-cookies.com./index.html", ""}, + }, + }, + { + "TestNonDottedAndTLD #3.", + "http://a.b", + []string{ + "a=1; domain=.b", + "b=2; domain=b"}, + "", + []query{{"http://bar.foo", ""}}, + }, + { + "TestNonDottedAndTLD #4.", + "http://google.com", + []string{ + "a=1; domain=.com", + "b=2; domain=com"}, + "", + []query{{"http://google.com", ""}}, + }, + { + "TestNonDottedAndTLD #5.", + "http://google.co.uk", + []string{ + "a=1; domain=.co.uk", + "b=2; domain=.uk"}, + "", + []query{ + {"http://google.co.uk", ""}, + {"http://else.co.com", ""}, + {"http://else.uk", ""}, + }, + }, + { + "TestHostEndsWithDot.", + "http://www.google.com", + []string{ + "a=1", + "b=2; domain=.www.google.com."}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "PathTest", + "http://www.google.izzle", + []string{"a=1; path=/wee"}, + "a=1", + []query{ + {"http://www.google.izzle/wee", "a=1"}, + {"http://www.google.izzle/wee/", "a=1"}, + {"http://www.google.izzle/wee/war", "a=1"}, + {"http://www.google.izzle/wee/war/more/more", "a=1"}, + {"http://www.google.izzle/weehee", ""}, + {"http://www.google.izzle/", ""}, + }, + }, +} + +func TestChromiumBasics(t *testing.T) { + for _, test := range chromiumBasicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// chromiumDomainTests contains jarTests which must be executed all on the +// same Jar. +var chromiumDomainTests = [...]jarTest{ + { + "Fill #1.", + "http://www.google.izzle", + []string{"A=B"}, + "A=B", + []query{{"http://www.google.izzle", "A=B"}}, + }, + { + "Fill #2.", + "http://www.google.izzle", + []string{"C=D; domain=.google.izzle"}, + "A=B C=D", + []query{{"http://www.google.izzle", "A=B C=D"}}, + }, + { + "Verify A is a host cookie and not accessible from subdomain.", + "http://unused.nil", + []string{}, + "A=B C=D", + []query{{"http://foo.www.google.izzle", "C=D"}}, + }, + { + "Verify domain cookies are found on proper domain.", + "http://www.google.izzle", + []string{"E=F; domain=.www.google.izzle"}, + "A=B C=D E=F", + []query{{"http://www.google.izzle", "A=B C=D E=F"}}, + }, + { + "Leading dots in domain attributes are optional.", + "http://www.google.izzle", + []string{"G=H; domain=www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #1.", + "http://www.google.izzle", + []string{"K=L; domain=.bar.www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #2.", + "http://unused.nil", + []string{}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, +} + +func TestChromiumDomain(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDomainTests { + test.run(t, jar) + } + +} + +// chromiumDeletionTests must be performed all on the same Jar. +var chromiumDeletionTests = [...]jarTest{ + { + "Create session cookie a1.", + "http://www.google.com", + []string{"a=1"}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "Delete sc a1 via MaxAge.", + "http://www.google.com", + []string{"a=1; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create session cookie b2.", + "http://www.google.com", + []string{"b=2"}, + "b=2", + []query{{"http://www.google.com", "b=2"}}, + }, + { + "Delete sc b2 via Expires.", + "http://www.google.com", + []string{"b=2; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie c3.", + "http://www.google.com", + []string{"c=3; max-age=3600"}, + "c=3", + []query{{"http://www.google.com", "c=3"}}, + }, + { + "Delete pc c3 via MaxAge.", + "http://www.google.com", + []string{"c=3; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie d4.", + "http://www.google.com", + []string{"d=4; max-age=3600"}, + "d=4", + []query{{"http://www.google.com", "d=4"}}, + }, + { + "Delete pc d4 via Expires.", + "http://www.google.com", + []string{"d=4; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, +} + +func TestChromiumDeletion(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDeletionTests { + test.run(t, jar) + } +} + +// domainHandlingTests tests and documents the rules for domain handling. +// Each test must be performed on an empty new Jar. +var domainHandlingTests = [...]jarTest{ + { + "Host cookie", + "http://www.host.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", ""}, + {"http://bar.host.test", ""}, + {"http://foo.www.host.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #1", + "http://www.host.test", + []string{"a=1; domain=host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #2", + "http://www.host.test", + []string{"a=1; domain=.host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #1", + "http://www.bรผcher.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bรผcher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bรผcher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bรผcher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bรผcher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bรผcher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bรผcher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bรผcher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bรผcher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #1", + "http://www.bรผcher.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bรผcher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bรผcher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bรผcher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bรผcher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bรผcher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bรผcher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bรผcher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bรผcher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on TLD.", + "http://com", + []string{"a=1"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Domain cookie on TLD becomes a host cookie.", + "http://com", + []string{"a=1; domain=com"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Host cookie on public suffix.", + "http://co.uk", + []string{"a=1"}, + "a=1", + []query{ + {"http://co.uk", "a=1"}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, + { + "Domain cookie on public suffix is ignored.", + "http://some.co.uk", + []string{"a=1; domain=co.uk"}, + "", + []query{ + {"http://co.uk", ""}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, +} + +func TestDomainHandling(t *testing.T) { + for _, test := range domainHandlingTests { + jar := newTestJar() + test.run(t, jar) + } +} diff --git a/src/net/http/cookiejar/punycode.go b/src/net/http/cookiejar/punycode.go new file mode 100644 index 000000000..ea7ceb5ef --- /dev/null +++ b/src/net/http/cookiejar/punycode.go @@ -0,0 +1,159 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +// This file implements the Punycode algorithm from RFC 3492. + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// These parameter values are specified in section 5. +// +// All computation is done with int32s, so that overflow behavior is identical +// regardless of whether int is 32-bit or 64-bit. +const ( + base int32 = 36 + damp int32 = 700 + initialBias int32 = 72 + initialN int32 = 128 + skew int32 = 38 + tmax int32 = 26 + tmin int32 = 1 +) + +// encode encodes a string as specified in section 6.3 and prepends prefix to +// the result. +// +// The "while h < length(input)" line in the specification becomes "for +// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. +func encode(prefix, s string) (string, error) { + output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) + copy(output, prefix) + delta, n, bias := int32(0), initialN, initialBias + b, remaining := int32(0), int32(0) + for _, r := range s { + if r < 0x80 { + b++ + output = append(output, byte(r)) + } else { + remaining++ + } + } + h := b + if b > 0 { + output = append(output, '-') + } + for remaining != 0 { + m := int32(0x7fffffff) + for _, r := range s { + if m > r && r >= n { + m = r + } + } + delta += (m - n) * (h + 1) + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + n = m + for _, r := range s { + if r < n { + delta++ + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + continue + } + if r > n { + continue + } + q := delta + for k := base; ; k += base { + t := k - bias + if t < tmin { + t = tmin + } else if t > tmax { + t = tmax + } + if q < t { + break + } + output = append(output, encodeDigit(t+(q-t)%(base-t))) + q = (q - t) / (base - t) + } + output = append(output, encodeDigit(q)) + bias = adapt(delta, h+1, h == b) + delta = 0 + h++ + remaining-- + } + delta++ + n++ + } + return string(output), nil +} + +func encodeDigit(digit int32) byte { + switch { + case 0 <= digit && digit < 26: + return byte(digit + 'a') + case 26 <= digit && digit < 36: + return byte(digit + ('0' - 26)) + } + panic("cookiejar: internal error in punycode encoding") +} + +// adapt is the bias adaptation function specified in section 6.1. +func adapt(delta, numPoints int32, firstTime bool) int32 { + if firstTime { + delta /= damp + } else { + delta /= 2 + } + delta += delta / numPoints + k := int32(0) + for delta > ((base-tmin)*tmax)/2 { + delta /= base - tmin + k += base + } + return k + (base-tmin+1)*delta/(delta+skew) +} + +// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and +// friends) and not Punycode (RFC 3492) per se. + +// acePrefix is the ASCII Compatible Encoding prefix. +const acePrefix = "xn--" + +// toASCII converts a domain or domain label to its ASCII form. For example, +// toASCII("bรผcher.example.com") is "xn--bcher-kva.example.com", and +// toASCII("golang") is "golang". +func toASCII(s string) (string, error) { + if ascii(s) { + return s, nil + } + labels := strings.Split(s, ".") + for i, label := range labels { + if !ascii(label) { + a, err := encode(acePrefix, label) + if err != nil { + return "", err + } + labels[i] = a + } + } + return strings.Join(labels, "."), nil +} + +func ascii(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} diff --git a/src/net/http/cookiejar/punycode_test.go b/src/net/http/cookiejar/punycode_test.go new file mode 100644 index 000000000..0301de14e --- /dev/null +++ b/src/net/http/cookiejar/punycode_test.go @@ -0,0 +1,161 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "testing" +) + +var punycodeTestCases = [...]struct { + s, encoded string +}{ + {"", ""}, + {"-", "--"}, + {"-a", "-a-"}, + {"-a-", "-a--"}, + {"a", "a-"}, + {"a-", "a--"}, + {"a-b", "a-b-"}, + {"books", "books-"}, + {"bรผcher", "bcher-kva"}, + {"Helloไธ็", "Hello-ck1hg65u"}, + {"รผ", "tda"}, + {"รผรฝ", "tdac"}, + + // The test cases below come from RFC 3492 section 7.1 with Errata 3026. + { + // (A) Arabic (Egyptian). + "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" + + "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F", + "egbpdaj6bu4bxfgehfvwxn", + }, + { + // (B) Chinese (simplified). + "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587", + "ihqwcrb4cv8a8dqg056pqjye", + }, + { + // (C) Chinese (traditional). + "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587", + "ihqwctvzc91f659drss3x8bo0yb", + }, + { + // (D) Czech. + "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" + + "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" + + "\u0065\u0073\u006B\u0079", + "Proprostnemluvesky-uyb24dma41a", + }, + { + // (E) Hebrew. + "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" + + "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" + + "\u05D1\u05E8\u05D9\u05EA", + "4dbcagdahymbxekheh6e0a7fei0b", + }, + { + // (F) Hindi (Devanagari). + "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" + + "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" + + "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" + + "\u0939\u0948\u0902", + "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd", + }, + { + // (G) Japanese (kanji and hiragana). + "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" + + "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B", + "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa", + }, + { + // (H) Korean (Hangul syllables). + "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" + + "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" + + "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C", + "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" + + "psd879ccm6fea98c", + }, + { + // (I) Russian (Cyrillic). + "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" + + "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" + + "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" + + "\u0438", + "b1abfaaepdrnnbgefbadotcwatmq2g4l", + }, + { + // (J) Spanish. + "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" + + "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" + + "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" + + "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" + + "\u0061\u00F1\u006F\u006C", + "PorqunopuedensimplementehablarenEspaol-fmd56a", + }, + { + // (K) Vietnamese. + "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" + + "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" + + "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" + + "\u0056\u0069\u1EC7\u0074", + "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g", + }, + { + // (L) 3<nen>B<gumi><kinpachi><sensei>. + "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F", + "3B-ww4c5e180e575a65lsy2b", + }, + { + // (M) <amuro><namie>-with-SUPER-MONKEYS. + "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" + + "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" + + "\u004F\u004E\u004B\u0045\u0059\u0053", + "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n", + }, + { + // (N) Hello-Another-Way-<sorezore><no><basho>. + "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" + + "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" + + "\u305D\u308C\u305E\u308C\u306E\u5834\u6240", + "Hello-Another-Way--fc4qua05auwb3674vfr0b", + }, + { + // (O) <hitotsu><yane><no><shita>2. + "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032", + "2-u9tlzr9756bt3uc0v", + }, + { + // (P) Maji<de>Koi<suru>5<byou><mae> + "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" + + "\u308B\u0035\u79D2\u524D", + "MajiKoi5-783gue6qz075azm5e", + }, + { + // (Q) <pafii>de<runba> + "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0", + "de-jg4avhby1noc0d", + }, + { + // (R) <sono><supiido><de> + "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067", + "d9juau41awczczp", + }, + { + // (S) -> $1.00 <- + "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" + + "\u003C\u002D", + "-> $1.00 <--", + }, +} + +func TestPunycode(t *testing.T) { + for _, tc := range punycodeTestCases { + if got, err := encode("", tc.s); err != nil { + t.Errorf(`encode("", %q): %v`, tc.s, err) + } else if got != tc.encoded { + t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded) + } + } +} diff --git a/src/net/http/doc.go b/src/net/http/doc.go new file mode 100644 index 000000000..b1216e8da --- /dev/null +++ b/src/net/http/doc.go @@ -0,0 +1,80 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package http provides HTTP client and server implementations. + +Get, Head, Post, and PostForm make HTTP (or HTTPS) requests: + + resp, err := http.Get("http://example.com/") + ... + resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf) + ... + resp, err := http.PostForm("http://example.com/form", + url.Values{"key": {"Value"}, "id": {"123"}}) + +The client must close the response body when finished with it: + + resp, err := http.Get("http://example.com/") + if err != nil { + // handle error + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + // ... + +For control over HTTP client headers, redirect policy, and other +settings, create a Client: + + client := &http.Client{ + CheckRedirect: redirectPolicyFunc, + } + + resp, err := client.Get("http://example.com") + // ... + + req, err := http.NewRequest("GET", "http://example.com", nil) + // ... + req.Header.Add("If-None-Match", `W/"wyzzy"`) + resp, err := client.Do(req) + // ... + +For control over proxies, TLS configuration, keep-alives, +compression, and other settings, create a Transport: + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: pool}, + DisableCompression: true, + } + client := &http.Client{Transport: tr} + resp, err := client.Get("https://example.com") + +Clients and Transports are safe for concurrent use by multiple +goroutines and for efficiency should only be created once and re-used. + +ListenAndServe starts an HTTP server with a given address and handler. +The handler is usually nil, which means to use DefaultServeMux. +Handle and HandleFunc add handlers to DefaultServeMux: + + http.Handle("/foo", fooHandler) + + http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) + }) + + log.Fatal(http.ListenAndServe(":8080", nil)) + +More control over the server's behavior is available by creating a +custom Server: + + s := &http.Server{ + Addr: ":8080", + Handler: myHandler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, + } + log.Fatal(s.ListenAndServe()) +*/ +package http diff --git a/src/net/http/example_test.go b/src/net/http/example_test.go new file mode 100644 index 000000000..88b97d9e3 --- /dev/null +++ b/src/net/http/example_test.go @@ -0,0 +1,88 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" +) + +func ExampleHijacker() { + http.HandleFunc("/hijack", func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) + return + } + conn, bufrw, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // Don't forget to close the connection: + defer conn.Close() + bufrw.WriteString("Now we're speaking raw TCP. Say hi: ") + bufrw.Flush() + s, err := bufrw.ReadString('\n') + if err != nil { + log.Printf("error reading string: %v", err) + return + } + fmt.Fprintf(bufrw, "You said: %q\nBye.\n", s) + bufrw.Flush() + }) +} + +func ExampleGet() { + res, err := http.Get("http://www.google.com/robots.txt") + if err != nil { + log.Fatal(err) + } + robots, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s", robots) +} + +func ExampleFileServer() { + // Simple static webserver: + log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc")))) +} + +func ExampleFileServer_stripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +func ExampleStripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +type apiHandler struct{} + +func (apiHandler) ServeHTTP(http.ResponseWriter, *http.Request) {} + +func ExampleServeMux_Handle() { + mux := http.NewServeMux() + mux.Handle("/api/", apiHandler{}) + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + // The "/" pattern matches everything, so we need to check + // that we're at the root here. + if req.URL.Path != "/" { + http.NotFound(w, req) + return + } + fmt.Fprintf(w, "Welcome to the home page!") + }) +} diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go new file mode 100644 index 000000000..87b6c0773 --- /dev/null +++ b/src/net/http/export_test.go @@ -0,0 +1,108 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Bridge package to expose http internals to tests in the http_test +// package. + +package http + +import ( + "net" + "net/url" + "time" +) + +func NewLoggingConn(baseName string, c net.Conn) net.Conn { + return newLoggingConn(baseName, c) +} + +var ExportAppendTime = appendTime + +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqCanceler) +} + +func (t *Transport) IdleConnKeysForTesting() (keys []string) { + keys = make([]string, 0) + t.idleMu.Lock() + defer t.idleMu.Unlock() + if t.idleConn == nil { + return + } + for key := range t.idleConn { + keys = append(keys, key.String()) + } + return +} + +func (t *Transport) IdleConnCountForTesting(cacheKey string) int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + if t.idleConn == nil { + return 0 + } + for k, conns := range t.idleConn { + if k.String() == cacheKey { + return len(conns) + } + } + return 0 +} + +func (t *Transport) IdleConnChMapSizeForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConnCh) +} + +func (t *Transport) IsIdleForTesting() bool { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return t.wantIdle +} + +func (t *Transport) RequestIdleConnChForTesting() { + t.getIdleConnCh(connectMethod{nil, "http", "example.com"}) +} + +func (t *Transport) PutIdleTestConn() bool { + c, _ := net.Pipe() + return t.putIdleConn(&persistConn{ + t: t, + conn: c, // dummy + closech: make(chan struct{}), // so it can be closed + cacheKey: connectMethodKey{"", "http", "example.com"}, + }) +} + +func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { + f := func() <-chan time.Time { + return ch + } + return &timeoutHandler{handler, f, ""} +} + +func ResetCachedEnvironment() { + httpProxyEnv.reset() + httpsProxyEnv.reset() + noProxyEnv.reset() +} + +var DefaultUserAgent = defaultUserAgent + +func ExportRefererForURL(lastReq, newReq *url.URL) string { + return refererForURL(lastReq, newReq) +} + +// SetPendingDialHooks sets the hooks that run before and after handling +// pending dials. +func SetPendingDialHooks(before, after func()) { + prePendingDial, postPendingDial = before, after +} + +var ExportServerNewConn = (*Server).newConn + +var ExportCloseWriteAndWait = (*conn).closeWriteAndWait diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go new file mode 100644 index 000000000..a3beaa33a --- /dev/null +++ b/src/net/http/fcgi/child.go @@ -0,0 +1,305 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/cgi" + "os" + "strings" + "sync" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + w *bufWriter + wroteHeader bool +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(data []byte) (int, error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.w.Write(data) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } else if r.header.Get("Content-Type") == "" { + r.header.Set("Content-Type", "text/html; charset=utf-8") + } + + if r.header.Get("Date") == "" { + r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + } + + fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) + r.header.Write(r.w) + r.w.WriteString("\r\n") +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler + + mu sync.Mutex // protects requests: + requests map[uint16]*request // keyed by request ID +} + +func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child { + return &child{ + conn: newConn(rwc), + handler: handler, + requests: make(map[uint16]*request), + } +} + +func (c *child) serve() { + defer c.conn.Close() + var rec record + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + if err := c.handleRecord(&rec); err != nil { + return + } + } +} + +var errCloseConn = errors.New("fcgi: connection should be closed") + +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + +func (c *child) handleRecord(rec *record) error { + c.mu.Lock() + req, ok := c.requests[rec.h.Id] + c.mu.Unlock() + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + return nil + } + + switch rec.h.Type { + case typeBeginRequest: + if req != nil { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return errors.New("fcgi: received ID that is already in-flight") + } + + var br beginRequest + if err := br.read(rec.content()); err != nil { + return err + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + return nil + } + req = newRequest(rec.h.Id, br.flags) + c.mu.Lock() + c.requests[rec.h.Id] = req + c.mu.Unlock() + return nil + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + return nil + } + req.parseParams() + return nil + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } else { + body = emptyBody + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + return nil + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(typeGetValuesResult, 0, values) + return nil + case typeData: + // If the filter role is implemented, read the data stream here. + return nil + case typeAbortRequest: + println("abort") + c.mu.Lock() + delete(c.requests, rec.h.Id) + c.mu.Unlock() + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if !req.keepConn { + // connection will close upon return + return errCloseConn + } + return nil + default: + b := make([]byte, 8) + b[0] = byte(rec.h.Type) + c.conn.writeRecord(typeUnknownType, 0, b) + return nil + } +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := cgi.RequestFromMap(req.params) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error())) + } else { + httpReq.Body = body + c.handler.ServeHTTP(r, httpReq) + } + r.Close() + c.mu.Lock() + delete(c.requests, req.reqId) + c.mu.Unlock() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + + // Consume the entire body, so the host isn't still writing to + // us when we close the socket below in the !keepConn case, + // otherwise we'd send a RST. (golang.org/issue/4183) + // TODO(bradfitz): also bound this copy in time. Or send + // some sort of abort request to the host, so the host + // can properly cut off the client sending all the data. + // For now just bound it a little and + io.CopyN(ioutil.Discard, body, 100<<20) + body.Close() + + if !req.keepConn { + c.conn.Close() + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// goroutine for each. The goroutine reads requests and then calls handler +// to reply to them. +// If l is nil, Serve accepts connections from os.Stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) error { + if l == nil { + var err error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } +} diff --git a/src/net/http/fcgi/fcgi.go b/src/net/http/fcgi/fcgi.go new file mode 100644 index 000000000..06bba0488 --- /dev/null +++ b/src/net/http/fcgi/fcgi.go @@ -0,0 +1,274 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// Currently only the responder role is supported. +// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "io" + "sync" +) + +// recType is a record type, as defined by +// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8 +type recType uint8 + +const ( + typeBeginRequest recType = 1 + typeAbortRequest recType = 2 + typeEndRequest recType = 3 + typeParams recType = 4 + typeStdin recType = 5 + typeStdout recType = 6 + typeStderr recType = 7 + typeData recType = 8 + typeGetValues recType = 9 + typeGetValuesResult recType = 10 + typeUnknownType recType = 11 +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +const headerLen = 8 + +type header struct { + Version uint8 + Type recType + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) error { + if len(content) != 8 { + return errors.New("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType recType, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return errors.New("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(typeBeginRequest, reqId, b[:]) +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(v))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType recType, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType recType + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/src/net/http/fcgi/fcgi_test.go b/src/net/http/fcgi/fcgi_test.go new file mode 100644 index 000000000..6c7e1a9ce --- /dev/null +++ b/src/net/http/fcgi/fcgi_test.go @@ -0,0 +1,150 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "errors" + "io" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType recType + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} + +type writeOnlyConn struct { + buf []byte +} + +func (c *writeOnlyConn) Write(p []byte) (int, error) { + c.buf = append(c.buf, p...) + return len(p), nil +} + +func (c *writeOnlyConn) Read(p []byte) (int, error) { + return 0, errors.New("conn is write-only") +} + +func (c *writeOnlyConn) Close() error { + return nil +} + +func TestGetValues(t *testing.T) { + var rec record + rec.h.Type = typeGetValues + + wc := new(writeOnlyConn) + c := newChild(wc, nil) + err := c.handleRecord(&rec) + if err != nil { + t.Fatalf("handleRecord: %v", err) + } + + const want = "\x01\n\x00\x00\x00\x12\x06\x00" + + "\x0f\x01FCGI_MPXS_CONNS1" + + "\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00" + if got := string(wc.buf); got != want { + t.Errorf(" got: %q\nwant: %q\n", got, want) + } +} diff --git a/src/net/http/filetransport.go b/src/net/http/filetransport.go new file mode 100644 index 000000000..821787e0c --- /dev/null +++ b/src/net/http/filetransport.go @@ -0,0 +1,123 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "io" +) + +// fileTransport implements RoundTripper for the 'file' protocol. +type fileTransport struct { + fh fileHandler +} + +// NewFileTransport returns a new RoundTripper, serving the provided +// FileSystem. The returned RoundTripper ignores the URL host in its +// incoming requests, as well as most other properties of the +// request. +// +// The typical use case for NewFileTransport is to register the "file" +// protocol with a Transport, as in: +// +// t := &http.Transport{} +// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/"))) +// c := &http.Client{Transport: t} +// res, err := c.Get("file:///etc/passwd") +// ... +func NewFileTransport(fs FileSystem) RoundTripper { + return fileTransport{fileHandler{fs}} +} + +func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) { + // We start ServeHTTP in a goroutine, which may take a long + // time if the file is large. The newPopulateResponseWriter + // call returns a channel which either ServeHTTP or finish() + // sends our *Response on, once the *Response itself has been + // populated (even if the body itself is still being + // written to the res.Body, a pipe) + rw, resc := newPopulateResponseWriter() + go func() { + t.fh.ServeHTTP(rw, req) + rw.finish() + }() + return <-resc, nil +} + +func newPopulateResponseWriter() (*populateResponse, <-chan *Response) { + pr, pw := io.Pipe() + rw := &populateResponse{ + ch: make(chan *Response), + pw: pw, + res: &Response{ + Proto: "HTTP/1.0", + ProtoMajor: 1, + Header: make(Header), + Close: true, + Body: pr, + }, + } + return rw, rw.ch +} + +// populateResponse is a ResponseWriter that populates the *Response +// in res, and writes its body to a pipe connected to the response +// body. Once writes begin or finish() is called, the response is sent +// on ch. +type populateResponse struct { + res *Response + ch chan *Response + wroteHeader bool + hasContent bool + sentResponse bool + pw *io.PipeWriter +} + +func (pr *populateResponse) finish() { + if !pr.wroteHeader { + pr.WriteHeader(500) + } + if !pr.sentResponse { + pr.sendResponse() + } + pr.pw.Close() +} + +func (pr *populateResponse) sendResponse() { + if pr.sentResponse { + return + } + pr.sentResponse = true + + if pr.hasContent { + pr.res.ContentLength = -1 + } + pr.ch <- pr.res +} + +func (pr *populateResponse) Header() Header { + return pr.res.Header +} + +func (pr *populateResponse) WriteHeader(code int) { + if pr.wroteHeader { + return + } + pr.wroteHeader = true + + pr.res.StatusCode = code + pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code)) +} + +func (pr *populateResponse) Write(p []byte) (n int, err error) { + if !pr.wroteHeader { + pr.WriteHeader(StatusOK) + } + pr.hasContent = true + if !pr.sentResponse { + pr.sendResponse() + } + return pr.pw.Write(p) +} diff --git a/src/net/http/filetransport_test.go b/src/net/http/filetransport_test.go new file mode 100644 index 000000000..6f1a537e2 --- /dev/null +++ b/src/net/http/filetransport_test.go @@ -0,0 +1,65 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +func checker(t *testing.T) func(string, error) { + return func(call string, err error) { + if err == nil { + return + } + t.Fatalf("%s: %v", call, err) + } +} + +func TestFileTransport(t *testing.T) { + check := checker(t) + + dname, err := ioutil.TempDir("", "") + check("TempDir", err) + fname := filepath.Join(dname, "foo.txt") + err = ioutil.WriteFile(fname, []byte("Bar"), 0644) + check("WriteFile", err) + defer os.Remove(dname) + defer os.Remove(fname) + + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransport(Dir(dname))) + c := &Client{Transport: tr} + + fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} + for _, urlstr := range fooURLs { + res, err := c.Get(urlstr) + check("Get "+urlstr, err) + if res.StatusCode != 200 { + t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode) + } + if res.ContentLength != -1 { + t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength) + } + if res.Body == nil { + t.Fatalf("for %s, nil Body", urlstr) + } + slurp, err := ioutil.ReadAll(res.Body) + check("ReadAll "+urlstr, err) + if string(slurp) != "Bar" { + t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar") + } + } + + const badURL = "file://../no-exist.txt" + res, err := c.Get(badURL) + check("Get "+badURL, err) + if res.StatusCode != 404 { + t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) + } + res.Body.Close() +} diff --git a/src/net/http/fs.go b/src/net/http/fs.go new file mode 100644 index 000000000..e322f710a --- /dev/null +++ b/src/net/http/fs.go @@ -0,0 +1,556 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP file system request handler + +package http + +import ( + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/textproto" + "net/url" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" +) + +// A Dir implements FileSystem using the native file system restricted to a +// specific directory tree. +// +// While the FileSystem.Open method takes '/'-separated paths, a Dir's string +// value is a filename on the native file system, not a URL, so it is separated +// by filepath.Separator, which isn't necessarily '/'. +// +// An empty Dir is treated as ".". +type Dir string + +func (d Dir) Open(name string) (File, error) { + if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 || + strings.Contains(name, "\x00") { + return nil, errors.New("http: invalid character in file path") + } + dir := string(d) + if dir == "" { + dir = "." + } + f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name)))) + if err != nil { + return nil, err + } + return f, nil +} + +// A FileSystem implements access to a collection of named files. +// The elements in a file path are separated by slash ('/', U+002F) +// characters, regardless of host operating system convention. +type FileSystem interface { + Open(name string) (File, error) +} + +// A File is returned by a FileSystem's Open method and can be +// served by the FileServer implementation. +// +// The methods should behave the same as those on an *os.File. +type File interface { + io.Closer + io.Reader + Readdir(count int) ([]os.FileInfo, error) + Seek(offset int64, whence int) (int64, error) + Stat() (os.FileInfo, error) +} + +func dirList(w ResponseWriter, f File) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprintf(w, "<pre>\n") + for { + dirs, err := f.Readdir(100) + if err != nil || len(dirs) == 0 { + break + } + for _, d := range dirs { + name := d.Name() + if d.IsDir() { + name += "/" + } + // name may contain '?' or '#', which must be escaped to remain + // part of the URL path, and not indicate the start of a query + // string or fragment. + url := url.URL{Path: name} + fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name)) + } + } + fmt.Fprintf(w, "</pre>\n") +} + +// ServeContent replies to the request using the content in the +// provided ReadSeeker. The main benefit of ServeContent over io.Copy +// is that it handles Range requests properly, sets the MIME type, and +// handles If-Modified-Since requests. +// +// If the response's Content-Type header is not set, ServeContent +// first tries to deduce the type from name's file extension and, +// if that fails, falls back to reading the first block of the content +// and passing it to DetectContentType. +// The name is otherwise unused; in particular it can be empty and is +// never sent in the response. +// +// If modtime is not the zero time, ServeContent includes it in a +// Last-Modified header in the response. If the request includes an +// If-Modified-Since header, ServeContent uses modtime to decide +// whether the content needs to be sent at all. +// +// The content's Seek method must work: ServeContent uses +// a seek to the end of the content to determine its size. +// +// If the caller has set w's ETag header, ServeContent uses it to +// handle requests using If-Range and If-None-Match. +// +// Note that *os.File implements the io.ReadSeeker interface. +func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { + sizeFunc := func() (int64, error) { + size, err := content.Seek(0, os.SEEK_END) + if err != nil { + return 0, errSeeker + } + _, err = content.Seek(0, os.SEEK_SET) + if err != nil { + return 0, errSeeker + } + return size, nil + } + serveContent(w, req, name, modtime, sizeFunc, content) +} + +// errSeeker is returned by ServeContent's sizeFunc when the content +// doesn't seek properly. The underlying Seeker's error text isn't +// included in the sizeFunc reply so it's not sent over HTTP to end +// users. +var errSeeker = errors.New("seeker can't seek") + +// if name is empty, filename is unknown. (used for mime type, before sniffing) +// if modtime.IsZero(), modtime is unknown. +// content must be seeked to the beginning of the file. +// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. +func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) { + if checkLastModified(w, r, modtime) { + return + } + rangeReq, done := checkETag(w, r, modtime) + if done { + return + } + + code := StatusOK + + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. + ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { + ctype = mime.TypeByExtension(filepath.Ext(name)) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [sniffLen]byte + n, _ := io.ReadFull(content, buf[:]) + ctype = DetectContentType(buf[:n]) + _, err := content.Seek(0, os.SEEK_SET) // rewind to output whole file + if err != nil { + Error(w, "seeker can't seek", StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size, err := sizeFunc() + if err != nil { + Error(w, err.Error(), StatusInternalServerError) + return + } + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = content + if size >= 0 { + ranges, err := parseRange(rangeReq, size) + if err != nil { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 2616, Section 14.16: + // "When an HTTP message includes the content of a single + // range (for example, a response to a request for a + // single range, or to a request for a set of ranges + // that overlap without any holes), this content is + // transmitted with a Content-Range header, and a + // Content-Length header showing the number of bytes + // actually transferred. + // ... + // A response to a request for a single range MUST NOT + // be sent using the multipart/byteranges media type." + ra := ranges[0] + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, sendContent, sendSize) + } +} + +// modtime is the modification time of the resource to be served, or IsZero(). +// return value is whether this request is now complete. +func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { + if modtime.IsZero() { + return false + } + + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(StatusNotModified) + return true + } + w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) + return false +} + +// checkETag implements If-None-Match and If-Range checks. +// +// The ETag or modtime must have been previously set in the +// ResponseWriter's headers. The modtime is only compared at second +// granularity and may be the zero value to mean unknown. +// +// The return value is the effective request "Range" header to use and +// whether this request is now considered done. +func checkETag(w ResponseWriter, r *Request, modtime time.Time) (rangeReq string, done bool) { + etag := w.Header().get("Etag") + rangeReq = r.Header.get("Range") + + // Invalidate the range request if the entity doesn't match the one + // the client was expecting. + // "If-Range: version" means "ignore the Range: header unless version matches the + // current file." + // We only support ETag versions. + // The caller must have set the ETag on the response already. + if ir := r.Header.get("If-Range"); ir != "" && ir != etag { + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + timeMatches := false + if !modtime.IsZero() { + if t, err := ParseTime(ir); err == nil && t.Unix() == modtime.Unix() { + timeMatches = true + } + } + if !timeMatches { + rangeReq = "" + } + } + + if inm := r.Header.get("If-None-Match"); inm != "" { + // Must know ETag. + if etag == "" { + return rangeReq, false + } + + // TODO(bradfitz): non-GET/HEAD requests require more work: + // sending a different status code on matches, and + // also can't use weak cache validators (those with a "W/ + // prefix). But most users of ServeContent will be using + // it on GET or HEAD, so only support those for now. + if r.Method != "GET" && r.Method != "HEAD" { + return rangeReq, false + } + + // TODO(bradfitz): deal with comma-separated or multiple-valued + // list of If-None-match values. For now just handle the common + // case of a single item. + if inm == etag || inm == "*" { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(StatusNotModified) + return "", true + } + } + return rangeReq, false +} + +// name is '/'-separated, not filepath.Separator. +func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { + const indexPage = "/index.html" + + // redirect .../index.html to .../ + // can't use Redirect() because that would make the path absolute, + // which would be a problem running under StripPrefix + if strings.HasSuffix(r.URL.Path, indexPage) { + localRedirect(w, r, "./") + return + } + + f, err := fs.Open(name) + if err != nil { + // TODO expose actual error? + NotFound(w, r) + return + } + defer f.Close() + + d, err1 := f.Stat() + if err1 != nil { + // TODO expose actual error? + NotFound(w, r) + return + } + + if redirect { + // redirect to canonical path: / at end of directory url + // r.URL.Path always begins with / + url := r.URL.Path + if d.IsDir() { + if url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + } else { + if url[len(url)-1] == '/' { + localRedirect(w, r, "../"+path.Base(url)) + return + } + } + } + + // use contents of index.html for directory, if present + if d.IsDir() { + index := strings.TrimSuffix(name, "/") + indexPage + ff, err := fs.Open(index) + if err == nil { + defer ff.Close() + dd, err := ff.Stat() + if err == nil { + name = index + d = dd + f = ff + } + } + } + + // Still a directory? (we didn't find an index.html file) + if d.IsDir() { + if checkLastModified(w, r, d.ModTime()) { + return + } + dirList(w, f) + return + } + + // serveContent will check modification time + sizeFunc := func() (int64, error) { return d.Size(), nil } + serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f) +} + +// localRedirect gives a Moved Permanently response. +// It does not convert relative paths to absolute paths like Redirect does. +func localRedirect(w ResponseWriter, r *Request, newPath string) { + if q := r.URL.RawQuery; q != "" { + newPath += "?" + q + } + w.Header().Set("Location", newPath) + w.WriteHeader(StatusMovedPermanently) +} + +// ServeFile replies to the request with the contents of the named file or directory. +func ServeFile(w ResponseWriter, r *Request, name string) { + dir, file := filepath.Split(name) + serveFile(w, r, Dir(dir), file, false) +} + +type fileHandler struct { + root FileSystem +} + +// FileServer returns a handler that serves HTTP requests +// with the contents of the file system rooted at root. +// +// To use the operating system's file system implementation, +// use http.Dir: +// +// http.Handle("/", http.FileServer(http.Dir("/tmp"))) +func FileServer(root FileSystem) Handler { + return &fileHandler{root} +} + +func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { + upath := r.URL.Path + if !strings.HasPrefix(upath, "/") { + upath = "/" + upath + r.URL.Path = upath + } + serveFile(w, r, f.root, path.Clean(upath), true) +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 2616. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + for _, ra := range strings.Split(s[len(b):], ",") { + ra = strings.TrimSpace(ra) + if ra == "" { + continue + } + i := strings.Index(ra, "-") + if i < 0 { + return nil, errors.New("invalid range") + } + start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file. + i, err := strconv.ParseInt(end, 10, 64) + if err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i > size || i < 0 { + return nil, errors.New("invalid range") + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + return ranges, nil +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go new file mode 100644 index 000000000..8770d9b41 --- /dev/null +++ b/src/net/http/fs_test.go @@ -0,0 +1,917 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "mime" + "mime/multipart" + "net" + . "net/http" + "net/http/httptest" + "net/url" + "os" + "os/exec" + "path" + "path/filepath" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +const ( + testFile = "testdata/file" + testFileLen = 11 +) + +type wantRange struct { + start, end int64 // range [start,end) +} + +var itoa = strconv.Itoa + +var ServeFileRangeTests = []struct { + r string + code int + ranges []wantRange +}{ + {r: "", code: StatusOK}, + {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}}, + {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}}, + {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}}, + {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}}, + {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}}, + {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}}, + {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}}, + {r: "bytes=5-1000", code: StatusPartialContent, ranges: []wantRange{{5, testFileLen}}}, + {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request + {r: "bytes=0-" + itoa(testFileLen-2), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen - 1}}}, + {r: "bytes=0-" + itoa(testFileLen-1), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}}, + {r: "bytes=0-" + itoa(testFileLen), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}}, +} + +func TestServeFile(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + + var err error + + file, err := ioutil.ReadFile(testFile) + if err != nil { + t.Fatal("reading file:", err) + } + + // set up the Request (re-used for all tests) + var req Request + req.Header = make(Header) + if req.URL, err = url.Parse(ts.URL); err != nil { + t.Fatal("ParseURL:", err) + } + req.Method = "GET" + + // straight GET + _, body := getBody(t, "straight get", req) + if !bytes.Equal(body, file) { + t.Fatalf("body mismatch: got %q, want %q", body, file) + } + + // Range tests +Cases: + for _, rt := range ServeFileRangeTests { + if rt.r != "" { + req.Header.Set("Range", rt.r) + } + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) + if resp.StatusCode != rt.code { + t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) + } + if rt.code == StatusRequestedRangeNotSatisfiable { + continue + } + wantContentRange := "" + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + } + cr := resp.Header.Get("Content-Range") + if cr != wantContentRange { + t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange) + } + ct := resp.Header.Get("Content-Type") + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + if strings.HasPrefix(ct, "multipart/byteranges") { + t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct) + } + } + if len(rt.ranges) > 1 { + typ, params, err := mime.ParseMediaType(ct) + if err != nil { + t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err) + continue + } + if typ != "multipart/byteranges" { + t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ) + continue + } + if params["boundary"] == "" { + t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct) + continue + } + if g, w := resp.ContentLength, int64(len(body)); g != w { + t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w) + continue + } + mr := multipart.NewReader(bytes.NewReader(body), params["boundary"]) + for ri, rng := range rt.ranges { + part, err := mr.NextPart() + if err != nil { + t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err) + continue Cases + } + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { + t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) + } + body, err := ioutil.ReadAll(part) + if err != nil { + t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) + continue Cases + } + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + } + _, err = mr.NextPart() + if err != io.EOF { + t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err) + } + } + } +} + +var fsRedirectTestData = []struct { + original, redirect string +}{ + {"/test/index.html", "/test/"}, + {"/test/testdata", "/test/testdata/"}, + {"/test/testdata/file/", "/test/testdata/file"}, +} + +func TestFSRedirect(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) + defer ts.Close() + + for _, data := range fsRedirectTestData { + res, err := Get(ts.URL + data.original) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, e := res.Request.URL.Path, data.redirect; g != e { + t.Errorf("redirect from %s: got %s, want %s", data.original, g, e) + } + } +} + +type testFileSystem struct { + open func(name string) (File, error) +} + +func (fs *testFileSystem) Open(name string) (File, error) { + return fs.open(name) +} + +func TestFileServerCleans(t *testing.T) { + defer afterTest(t) + ch := make(chan string, 1) + fs := FileServer(&testFileSystem{func(name string) (File, error) { + ch <- name + return nil, errors.New("file does not exist") + }}) + tests := []struct { + reqPath, openArg string + }{ + {"/foo.txt", "/foo.txt"}, + {"//foo.txt", "/foo.txt"}, + {"/../foo.txt", "/foo.txt"}, + } + req, _ := NewRequest("GET", "http://example.com", nil) + for n, test := range tests { + rec := httptest.NewRecorder() + req.URL.Path = test.reqPath + fs.ServeHTTP(rec, req) + if got := <-ch; got != test.openArg { + t.Errorf("test %d: got %q, want %q", n, got, test.openArg) + } + } +} + +func TestFileServerEscapesNames(t *testing.T) { + defer afterTest(t) + const dirListPrefix = "<pre>\n" + const dirListSuffix = "\n</pre>\n" + tests := []struct { + name, escaped string + }{ + {`simple_name`, `<a href="simple_name">simple_name</a>`}, + {`"'<>&`, `<a href="%22%27%3C%3E&">"'<>&</a>`}, + {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`}, + {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo"><combo>?foo</a>`}, + } + + // We put each test file in its own directory in the fakeFS so we can look at it in isolation. + fs := make(fakeFS) + for i, test := range tests { + testFile := &fakeFileInfo{basename: test.name} + fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{ + dir: true, + modtime: time.Unix(1000000000, 0).UTC(), + ents: []*fakeFileInfo{testFile}, + } + fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile + } + + ts := httptest.NewServer(FileServer(&fs)) + defer ts.Close() + for i, test := range tests { + url := fmt.Sprintf("%s/%d", ts.URL, i) + res, err := Get(url) + if err != nil { + t.Fatalf("test %q: Get: %v", test.name, err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("test %q: read Body: %v", test.name, err) + } + s := string(b) + if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) { + t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix) + } + if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped { + t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped) + } + res.Body.Close() + } +} + +func mustRemoveAll(dir string) { + err := os.RemoveAll(dir) + if err != nil { + panic(err) + } +} + +func TestFileServerImplicitLeadingSlash(t *testing.T) { + defer afterTest(t) + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("TempDir: %v", err) + } + defer mustRemoveAll(tempDir) + if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) + defer ts.Close() + get := func(suffix string) string { + res, err := Get(ts.URL + suffix) + if err != nil { + t.Fatalf("Get %s: %v", suffix, err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll %s: %v", suffix, err) + } + res.Body.Close() + return string(b) + } + if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") { + t.Logf("expected a directory listing with foo.txt, got %q", s) + } + if s := get("/bar/foo.txt"); s != "Hello world" { + t.Logf("expected %q, got %q", "Hello world", s) + } +} + +func TestDirJoin(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping test on windows") + } + wfi, err := os.Stat("/etc/hosts") + if err != nil { + t.Skip("skipping test; no /etc/hosts file") + } + test := func(d Dir, name string) { + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + gfi, err := f.Stat() + if err != nil { + t.Fatalf("stat of %s: %v", name, err) + } + if !os.SameFile(gfi, wfi) { + t.Errorf("%s got different file", name) + } + } + test(Dir("/etc/"), "/hosts") + test(Dir("/etc/"), "hosts") + test(Dir("/etc/"), "../../../../hosts") + test(Dir("/etc"), "/hosts") + test(Dir("/etc"), "hosts") + test(Dir("/etc"), "../../../../hosts") + + // Not really directories, but since we use this trick in + // ServeFile, test it: + test(Dir("/etc/hosts"), "") + test(Dir("/etc/hosts"), "/") + test(Dir("/etc/hosts"), "../") +} + +func TestEmptyDirOpenCWD(t *testing.T) { + test := func(d Dir) { + name := "fs_test.go" + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + } + test(Dir("")) + test(Dir(".")) + test(Dir("./")) +} + +func TestServeFileContentType(t *testing.T) { + defer afterTest(t) + const ctype = "icecream/chocolate" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.FormValue("override") { + case "1": + w.Header().Set("Content-Type", ctype) + case "2": + // Explicitly inhibit sniffing. + w.Header()["Content-Type"] = []string{} + } + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + get := func(override string, want []string) { + resp, err := Get(ts.URL + "?override=" + override) + if err != nil { + t.Fatal(err) + } + if h := resp.Header["Content-Type"]; !reflect.DeepEqual(h, want) { + t.Errorf("Content-Type mismatch: got %v, want %v", h, want) + } + resp.Body.Close() + } + get("0", []string{"text/plain; charset=utf-8"}) + get("1", []string{ctype}) + get("2", nil) +} + +func TestServeFileMimeType(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/style.css") + })) + defer ts.Close() + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + want := "text/css; charset=utf-8" + if h := resp.Header.Get("Content-Type"); h != want { + t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + } +} + +func TestServeFileFromCWD(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "fs_test.go") + })) + defer ts.Close() + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + r.Body.Close() + if r.StatusCode != 200 { + t.Fatalf("expected 200 OK, got %s", r.Status) + } +} + +func TestServeFileWithContentEncoding(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "foo") + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if g, e := resp.ContentLength, int64(-1); g != e { + t.Errorf("Content-Length mismatch: got %d, want %d", g, e) + } +} + +func TestServeIndexHtml(t *testing.T) { + defer afterTest(t) + const want = "index.html says hello\n" + ts := httptest.NewServer(FileServer(Dir("."))) + defer ts.Close() + + for _, path := range []string{"/testdata/", "/testdata/index.html"} { + res, err := Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != want { + t.Errorf("for path %q got %q, want %q", path, s, want) + } + res.Body.Close() + } +} + +func TestFileServerZeroByte(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(FileServer(Dir("."))) + defer ts.Close() + + res, err := Get(ts.URL + "/..\x00") + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if res.StatusCode == 200 { + t.Errorf("got status 200; want an error. Body is:\n%s", string(b)) + } +} + +type fakeFileInfo struct { + dir bool + basename string + modtime time.Time + ents []*fakeFileInfo + contents string +} + +func (f *fakeFileInfo) Name() string { return f.basename } +func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } +func (f *fakeFileInfo) IsDir() bool { return f.dir } +func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } +func (f *fakeFileInfo) Mode() os.FileMode { + if f.dir { + return 0755 | os.ModeDir + } + return 0644 +} + +type fakeFile struct { + io.ReadSeeker + fi *fakeFileInfo + path string // as opened + entpos int +} + +func (f *fakeFile) Close() error { return nil } +func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { + if !f.fi.dir { + return nil, os.ErrInvalid + } + var fis []os.FileInfo + + limit := f.entpos + count + if count <= 0 || limit > len(f.fi.ents) { + limit = len(f.fi.ents) + } + for ; f.entpos < limit; f.entpos++ { + fis = append(fis, f.fi.ents[f.entpos]) + } + + if len(fis) == 0 && count > 0 { + return fis, io.EOF + } else { + return fis, nil + } +} + +type fakeFS map[string]*fakeFileInfo + +func (fs fakeFS) Open(name string) (File, error) { + name = path.Clean(name) + f, ok := fs[name] + if !ok { + return nil, os.ErrNotExist + } + return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil +} + +func TestDirectoryIfNotModified(t *testing.T) { + defer afterTest(t) + const indexContents = "I am a fake index.html file" + fileMod := time.Unix(1000000000, 0).UTC() + fileModStr := fileMod.Format(TimeFormat) + dirMod := time.Unix(123, 0).UTC() + indexFile := &fakeFileInfo{ + basename: "index.html", + modtime: fileMod, + contents: indexContents, + } + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{indexFile}, + }, + "/index.html": indexFile, + } + + ts := httptest.NewServer(FileServer(fs)) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(b) != indexContents { + t.Fatalf("Got body %q; want %q", b, indexContents) + } + res.Body.Close() + + lastMod := res.Header.Get("Last-Modified") + if lastMod != fileModStr { + t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr) + } + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("If-Modified-Since", lastMod) + + res, err = DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 304 { + t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode) + } + res.Body.Close() + + // Advance the index.html file's modtime, but not the directory's. + indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) + + res, err = DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res) + } + res.Body.Close() +} + +func mustStat(t *testing.T, fileName string) os.FileInfo { + fi, err := os.Stat(fileName) + if err != nil { + t.Fatal(err) + } + return fi +} + +func TestServeContent(t *testing.T) { + defer afterTest(t) + type serveParam struct { + name string + modtime time.Time + content io.ReadSeeker + contentType string + etag string + } + servec := make(chan serveParam, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + p := <-servec + if p.etag != "" { + w.Header().Set("ETag", p.etag) + } + if p.contentType != "" { + w.Header().Set("Content-Type", p.contentType) + } + ServeContent(w, r, p.name, p.modtime, p.content) + })) + defer ts.Close() + + type testCase struct { + // One of file or content must be set: + file string + content io.ReadSeeker + + modtime time.Time + serveETag string // optional + serveContentType string // optional + reqHeader map[string]string + wantLastMod string + wantContentType string + wantStatus int + } + htmlModTime := mustStat(t, "testdata/index.html").ModTime() + tests := map[string]testCase{ + "no_last_modified": { + file: "testdata/style.css", + wantContentType: "text/css; charset=utf-8", + wantStatus: 200, + }, + "with_last_modified": { + file: "testdata/index.html", + wantContentType: "text/html; charset=utf-8", + modtime: htmlModTime, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + wantStatus: 200, + }, + "not_modified_modtime": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_modtime_with_contenttype": { + file: "testdata/style.css", + serveContentType: "text/css", // explicit content type + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_etag": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, + "not_modified_etag_no_seek": { + content: panicOnSeek{nil}, // should never be called + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, + "range_good": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + }, + // An If-Range resource for entity "A", but entity "B" is now current. + // The Range request should be ignored. + "range_no_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"B"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "range_with_modtime": { + file: "testdata/style.css", + modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC), + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + }, + "range_with_modtime_nanos": { + file: "testdata/style.css", + modtime: time.Date(2014, 6, 25, 17, 12, 18, 123 /* nanos */, time.UTC), + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + }, + } + for testName, tt := range tests { + var content io.ReadSeeker + if tt.file != "" { + f, err := os.Open(tt.file) + if err != nil { + t.Fatalf("test %q: %v", testName, err) + } + defer f.Close() + content = f + } else { + content = tt.content + } + + servec <- serveParam{ + name: filepath.Base(tt.file), + content: content, + modtime: tt.modtime, + etag: tt.serveETag, + contentType: tt.serveContentType, + } + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.reqHeader { + req.Header.Set(k, v) + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if res.StatusCode != tt.wantStatus { + t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus) + } + if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { + t.Errorf("test %q: content-type = %q, want %q", testName, g, e) + } + if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { + t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) + } + } +} + +// verifies that sendfile is being used on Linux +func TestLinuxSendfile(t *testing.T) { + defer afterTest(t) + if runtime.GOOS != "linux" { + t.Skip("skipping; linux-only test") + } + if _, err := exec.LookPath("strace"); err != nil { + t.Skip("skipping; strace not found in path") + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + lnf, err := ln.(*net.TCPListener).File() + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + var buf bytes.Buffer + child := exec.Command("strace", "-f", "-q", "-e", "trace=sendfile,sendfile64", os.Args[0], "-test.run=TestLinuxSendfileChild") + child.ExtraFiles = append(child.ExtraFiles, lnf) + child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) + child.Stdout = &buf + child.Stderr = &buf + if err := child.Start(); err != nil { + t.Skipf("skipping; failed to start straced child: %v", err) + } + + res, err := Get(fmt.Sprintf("http://%s/", ln.Addr())) + if err != nil { + t.Fatalf("http client error: %v", err) + } + _, err = io.Copy(ioutil.Discard, res.Body) + if err != nil { + t.Fatalf("client body read error: %v", err) + } + res.Body.Close() + + // Force child to exit cleanly. + Get(fmt.Sprintf("http://%s/quit", ln.Addr())) + child.Wait() + + rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`) + rxResume := regexp.MustCompile(`<\.\.\. sendfile(64)? resumed> \)\s*=\s*\d+\s*\n`) + out := buf.String() + if !rx.MatchString(out) && !rxResume.MatchString(out) { + t.Errorf("no sendfile system call found in:\n%s", out) + } +} + +func getBody(t *testing.T, testName string, req Request) (*Response, []byte) { + r, err := DefaultClient.Do(&req) + if err != nil { + t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) + } + b, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err) + } + return r, b +} + +// TestLinuxSendfileChild isn't a real test. It's used as a helper process +// for TestLinuxSendfile. +func TestLinuxSendfileChild(*testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + defer os.Exit(0) + fd3 := os.NewFile(3, "ephemeral-port-listener") + ln, err := net.FileListener(fd3) + if err != nil { + panic(err) + } + mux := NewServeMux() + mux.Handle("/", FileServer(Dir("testdata"))) + mux.HandleFunc("/quit", func(ResponseWriter, *Request) { + os.Exit(0) + }) + s := &Server{Handler: mux} + err = s.Serve(ln) + if err != nil { + panic(err) + } +} + +func TestFileServerCleanPath(t *testing.T) { + tests := []struct { + path string + wantCode int + wantOpen []string + }{ + {"/", 200, []string{"/", "/index.html"}}, + {"/dir", 301, []string{"/dir"}}, + {"/dir/", 200, []string{"/dir", "/dir/index.html"}}, + } + for _, tt := range tests { + var log []string + rr := httptest.NewRecorder() + req, _ := NewRequest("GET", "http://foo.localhost"+tt.path, nil) + FileServer(fileServerCleanPathDir{&log}).ServeHTTP(rr, req) + if !reflect.DeepEqual(log, tt.wantOpen) { + t.Logf("For %s: Opens = %q; want %q", tt.path, log, tt.wantOpen) + } + if rr.Code != tt.wantCode { + t.Logf("For %s: Response code = %d; want %d", tt.path, rr.Code, tt.wantCode) + } + } +} + +type fileServerCleanPathDir struct { + log *[]string +} + +func (d fileServerCleanPathDir) Open(path string) (File, error) { + *(d.log) = append(*(d.log), path) + if path == "/" || path == "/dir" || path == "/dir/" { + // Just return back something that's a directory. + return Dir(".").Open(".") + } + return nil, os.ErrNotExist +} + +type panicOnSeek struct{ io.ReadSeeker } diff --git a/src/net/http/header.go b/src/net/http/header.go new file mode 100644 index 000000000..153b94370 --- /dev/null +++ b/src/net/http/header.go @@ -0,0 +1,211 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "io" + "net/textproto" + "sort" + "strings" + "sync" + "time" +) + +var raceEnabled = false // set by race.go + +// A Header represents the key-value pairs in an HTTP header. +type Header map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns "". +// To access multiple values of a key, access the map directly +// with CanonicalHeaderKey. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// Del deletes the values associated with key. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) error { + return h.WriteSubset(w, nil) +} + +func (h Header) clone() Header { + h2 := make(Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +var timeFormats = []string{ + TimeFormat, + time.RFC850, + time.ANSIC, +} + +// ParseTime parses a time header (such as the Date: header), +// trying each of the three formats allowed by HTTP/1.1: +// TimeFormat, time.RFC850, and time.ANSIC. +func ParseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} + +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type writeStringer interface { + WriteString(string) (int, error) +} + +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) +} + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() interface{} { return new(headerSorter) }, +} + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { + ws, ok := w.(writeStringer) + if !ok { + ws = stringWriter{w} + } + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + return err + } + } + } + } + headerSorterPool.Put(sorter) + return nil +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if strings.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} diff --git a/src/net/http/header_test.go b/src/net/http/header_test.go new file mode 100644 index 000000000..9dcd591fa --- /dev/null +++ b/src/net/http/header_test.go @@ -0,0 +1,212 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "runtime" + "testing" + "time" +) + +var headerWriteTests = []struct { + h Header + exclude map[string]bool + expected string +}{ + {Header{}, nil, ""}, + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "Content-Length": {"0"}, + }, + nil, + "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n", + }, + { + Header{ + "Content-Length": {"0", "1", "2"}, + }, + nil, + "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0", "1", "2"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, + "", + }, + { + Header{ + "Nil": nil, + "Empty": {}, + "Blank": {""}, + "Double-Blank": {"", ""}, + }, + nil, + "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", + }, + // Tests header sorting when over the insertion sort threshold side: + { + Header{ + "k1": {"1a", "1b"}, + "k2": {"2a", "2b"}, + "k3": {"3a", "3b"}, + "k4": {"4a", "4b"}, + "k5": {"5a", "5b"}, + "k6": {"6a", "6b"}, + "k7": {"7a", "7b"}, + "k8": {"8a", "8b"}, + "k9": {"9a", "9b"}, + }, + map[string]bool{"k5": true}, + "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" + + "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", + }, +} + +func TestHeaderWrite(t *testing.T) { + var buf bytes.Buffer + for i, test := range headerWriteTests { + test.h.WriteSubset(&buf, test.exclude) + if buf.String() != test.expected { + t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected) + } + buf.Reset() + } +} + +var parseTimeTests = []struct { + h Header + err bool +}{ + {Header{"Date": {""}}, true}, + {Header{"Date": {"invalid"}}, true}, + {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true}, + {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false}, +} + +func TestParseTime(t *testing.T) { + expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC) + for i, test := range parseTimeTests { + d, err := ParseTime(test.h.Get("Date")) + if err != nil { + if !test.err { + t.Errorf("#%d:\n got err: %v", i, err) + } + continue + } + if test.err { + t.Errorf("#%d:\n should err", i) + continue + } + if !expect.Equal(d) { + t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect) + } + } +} + +type hasTokenTest struct { + header string + token string + want bool +} + +var hasTokenTests = []hasTokenTest{ + {"", "", false}, + {"", "foo", false}, + {"foo", "foo", true}, + {"foo ", "foo", true}, + {" foo", "foo", true}, + {" foo ", "foo", true}, + {"foo,bar", "foo", true}, + {"bar,foo", "foo", true}, + {"bar, foo", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo,baz", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo, baz", "foo", true}, + {"FOO", "foo", true}, + {"FOO ", "foo", true}, + {" FOO", "foo", true}, + {" FOO ", "foo", true}, + {"FOO,BAR", "foo", true}, + {"BAR,FOO", "foo", true}, + {"BAR, FOO", "foo", true}, + {"BAR,FOO, baz", "foo", true}, + {"BAR, FOO,BAZ", "foo", true}, + {"BAR,FOO, BAZ", "foo", true}, + {"BAR, FOO, BAZ", "foo", true}, + {"foobar", "foo", false}, + {"barfoo ", "foo", false}, +} + +func TestHasToken(t *testing.T) { + for _, tt := range hasTokenTests { + if hasToken(tt.header, tt.token) != tt.want { + t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want) + } + } +} + +var testHeader = Header{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {DefaultUserAgent}, +} + +var buf bytes.Buffer + +func BenchmarkHeaderWriteSubset(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + } +} + +func TestHeaderWriteSubsetAllocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping alloc test in short mode") + } + if raceEnabled { + t.Skip("skipping test under race detector") + } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + n := testing.AllocsPerRun(100, func() { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + }) + if n > 0 { + t.Errorf("allocs = %g; want 0", n) + } +} diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go new file mode 100644 index 000000000..42a0ec953 --- /dev/null +++ b/src/net/http/httptest/example_test.go @@ -0,0 +1,50 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" +) + +func ExampleResponseRecorder() { + handler := func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "something failed", http.StatusInternalServerError) + } + + req, err := http.NewRequest("GET", "http://example.com/foo", nil) + if err != nil { + log.Fatal(err) + } + + w := httptest.NewRecorder() + handler(w, req) + + fmt.Printf("%d - %s", w.Code, w.Body.String()) + // Output: 500 - something failed +} + +func ExampleServer() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go new file mode 100644 index 000000000..5451f5423 --- /dev/null +++ b/src/net/http/httptest/recorder.go @@ -0,0 +1,72 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptest provides utilities for HTTP testing. +package httptest + +import ( + "bytes" + "net/http" +) + +// ResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + Code int // the HTTP response code from WriteHeader + HeaderMap http.Header // the HTTP response headers + Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to + Flushed bool + + wroteHeader bool +} + +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: 200, + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +const DefaultRemoteAddr = "1.2.3.4" + +// Header returns the response headers. +func (rw *ResponseRecorder) Header() http.Header { + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m +} + +// Write always succeeds and writes to rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteHeader sets rw.Code. +func (rw *ResponseRecorder) WriteHeader(code int) { + if !rw.wroteHeader { + rw.Code = code + } + rw.wroteHeader = true +} + +// Flush sets rw.Flushed to true. +func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + rw.Flushed = true +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go new file mode 100644 index 000000000..2b563260c --- /dev/null +++ b/src/net/http/httptest/recorder_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + + tests := []struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true)), + }, + } + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + for _, tt := range tests { + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +} diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go new file mode 100644 index 000000000..789e7bf41 --- /dev/null +++ b/src/net/http/httptest/server.go @@ -0,0 +1,228 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Implementation of Server + +package httptest + +import ( + "crypto/tls" + "flag" + "fmt" + "net" + "net/http" + "os" + "sync" +) + +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server + + // wg counts the number of outstanding HTTP requests on this server. + // Close blocks until all requests are finished. + wg sync.WaitGroup +} + +// historyListener keeps track of all connections that it's ever +// accepted. +type historyListener struct { + net.Listener + sync.Mutex // protects history + history []net.Conn +} + +func (hs *historyListener) Accept() (c net.Conn, err error) { + c, err = hs.Listener.Accept() + if err == nil { + hs.Lock() + hs.history = append(hs.history, c) + hs.Unlock() + } + return +} + +func newLocalListener() net.Listener { + if *serve != "" { + l, err := net.Listen("tcp", *serve) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err)) + } + return l + } + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return l +} + +// When debugging a particular http server-based test, +// this flag lets you run +// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// to start the broken server so you can interact with it manually. +var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks") + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.Start() + return ts +} + +// NewUnstartedServer returns a new Server but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, + } +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + s.Listener = &historyListener{Listener: s.Listener} + s.URL = "http://" + s.Listener.Addr().String() + s.wrapHandler() + go s.Config.Serve(s.Listener) + if *serve != "" { + fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) + select {} + } +} + +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS() { + if s.URL != "" { + panic("Server already started") + } + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + existingConfig := s.TLS + s.TLS = new(tls.Config) + if existingConfig != nil { + *s.TLS = *existingConfig + } + if s.TLS.NextProtos == nil { + s.TLS.NextProtos = []string{"http/1.1"} + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} + } + tlsListener := tls.NewListener(s.Listener, s.TLS) + + s.Listener = &historyListener{Listener: tlsListener} + s.URL = "https://" + s.Listener.Addr().String() + s.wrapHandler() + go s.Config.Serve(s.Listener) +} + +func (s *Server) wrapHandler() { + h := s.Config.Handler + if h == nil { + h = http.DefaultServeMux + } + s.Config.Handler = &waitGroupHandler{ + s: s, + h: h, + } +} + +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS() + return ts +} + +// Close shuts down the server and blocks until all outstanding +// requests on this server have completed. +func (s *Server) Close() { + s.Listener.Close() + s.wg.Wait() + s.CloseClientConnections() + if t, ok := http.DefaultTransport.(*http.Transport); ok { + t.CloseIdleConnections() + } +} + +// CloseClientConnections closes any currently open HTTP connections +// to the test Server. +func (s *Server) CloseClientConnections() { + hl, ok := s.Listener.(*historyListener) + if !ok { + return + } + hl.Lock() + for _, conn := range hl.history { + conn.Close() + } + hl.Unlock() +} + +// waitGroupHandler wraps a handler, incrementing and decrementing a +// sync.WaitGroup on each request, to enable Server.Close to block +// until outstanding requests are finished. +type waitGroupHandler struct { + s *Server + h http.Handler // non-nil +} + +func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.s.wg.Add(1) + defer h.s.wg.Done() // a defer, in case ServeHTTP below panics + h.h.ServeHTTP(w, r) +} + +// localhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end +// of ASN.1 time). +// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go new file mode 100644 index 000000000..500a9f0b8 --- /dev/null +++ b/src/net/http/httptest/server_test.go @@ -0,0 +1,29 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "io/ioutil" + "net/http" + "testing" +) + +func TestServer(t *testing.T) { + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go new file mode 100644 index 000000000..ac8f103f9 --- /dev/null +++ b/src/net/http/httputil/dump.go @@ -0,0 +1,284 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// One of the copies, say from b to r2, could be avoided by using a more +// elaborate trick where the other copy is made during Request/Response.Write. +// This would complicate things too much, given that these functions are for +// debugging only. +func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, nil, err + } + if err = b.Close(); err != nil { + return nil, nil, err + } + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// DumpRequestOut is like DumpRequest but includes +// headers that the standard http.Transport adds, +// such as User-Agent. +func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { + save := req.Body + dummyBody := false + if !body || req.Body == nil { + req.Body = nil + if req.ContentLength != 0 { + req.Body = ioutil.NopCloser(io.LimitReader(neverEnding('x'), req.ContentLength)) + dummyBody = true + } + } else { + var err error + save, req.Body, err = drainBody(req.Body) + if err != nil { + return nil, err + } + } + + // Since we're using the actual Transport code to write the request, + // switch to http so the Transport doesn't try to do an SSL + // negotiation with our dumpConn and its bytes.Buffer & pipe. + // The wire format for https and http are the same, anyway. + reqSend := req + if req.URL.Scheme == "https" { + reqSend = new(http.Request) + *reqSend = *req + reqSend.URL = new(url.URL) + *reqSend.URL = *req.URL + reqSend.URL.Scheme = "http" + } + + // Use the actual Transport code to record what we would send + // on the wire, but not using TCP. Use a Transport with a + // custom dialer that returns a fake net.Conn that waits + // for the full input (and recording it), and then responds + // with a dummy response. + var buf bytes.Buffer // records the output + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + // Wait for the request before replying with a dummy response: + go func() { + req, err := http.ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + } + dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\n\r\n") + }() + + t := &http.Transport{ + DisableKeepAlives: true, + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil + }, + } + + _, err := t.RoundTrip(reqSend) + + req.Body = save + if err != nil { + return nil, err + } + dump := buf.Bytes() + + // If we used a dummy body above, remove it now. + // TODO: if the req.ContentLength is large, we allocate memory + // unnecessarily just to slice it off here. But this is just + // a debug function, so this is acceptable for now. We could + // discard the body earlier if this matters. + if dummyBody { + if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 { + dump = dump[:i+4] + } + } + return dump, nil +} + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + r.r = <-r.c + } + return r.r.Read(p) +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +var reqWriteExcludeHeaderDump = map[string]bool{ + "Host": true, // not in Header map anyway + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// dumpAsReceived writes req to w in the form as it was received, or +// at least as accurately as possible from the information retained in +// the request. +func dumpAsReceived(req *http.Request, w io.Writer) error { + return nil +} + +// DumpRequest returns the as-received wire representation of req, +// optionally including the request body, for debugging. +// DumpRequest is semantically a no-op, but in order to +// dump the body, it reads the body data into memory and +// changes req.Body to refer to the in-memory copy. +// The documentation for http.Request.Write details which fields +// of req are used. +func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { + save := req.Body + if !body || req.Body == nil { + req.Body = nil + } else { + save, req.Body, err = drainBody(req.Body) + if err != nil { + return + } + } + + var b bytes.Buffer + + fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), + req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor) + + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&b, "Host: %s\r\n", host) + } + + chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + if req.Close { + fmt.Fprintf(&b, "Connection: close\r\n") + } + + err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) + if err != nil { + return + } + + io.WriteString(&b, "\r\n") + + if req.Body != nil { + var dest io.Writer = &b + if chunked { + dest = NewChunkedWriter(dest) + } + _, err = io.Copy(dest, req.Body) + if chunked { + dest.(io.Closer).Close() + io.WriteString(&b, "\r\n") + } + } + + req.Body = save + if err != nil { + return + } + dump = b.Bytes() + return +} + +// errNoBody is a sentinel error value used by failureToReadBody so we can detect +// that the lack of body was intentional. +var errNoBody = errors.New("sentinel error value") + +// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// Read. It's swapped in when we don't actually want to consume the +// body, but need a non-nil one, and want to distinguish the error +// from reading the dummy body. +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + +// DumpResponse is like DumpRequest but dumps a response. +func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { + var b bytes.Buffer + save := resp.Body + savecl := resp.ContentLength + + if !body { + resp.Body = failureToReadBody{} + } else if resp.Body == nil { + resp.Body = emptyBody + } else { + save, resp.Body, err = drainBody(resp.Body) + if err != nil { + return + } + } + err = resp.Write(&b) + if err == errNoBody { + err = nil + } + resp.Body = save + resp.ContentLength = savecl + if err != nil { + return nil, err + } + return b.Bytes(), nil +} diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go new file mode 100644 index 000000000..024ee5a86 --- /dev/null +++ b/src/net/http/httputil/dump_test.go @@ -0,0 +1,291 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "runtime" + "strings" + "testing" +) + +type dumpTest struct { + Req http.Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + WantDump string + WantDumpOut string + NoBody bool // if true, set DumpRequest{,Out} body to false +} + +var dumpTests = []dumpTest{ + + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantDump: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: http.Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: http.Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantDump: "GET /foo HTTP/1.0\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, + + { + Req: *mustNewRequest("GET", "http://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Test that an https URL doesn't try to do an SSL negotiation + // with a bytes.Buffer and hang with all goroutines not + // runnable. + { + Req: *mustNewRequest("GET", "https://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Request with Body, but Dump requested without it. + { + Req: http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + ContentLength: 6, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: []byte("abcdef"), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 6\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + + NoBody: true, + }, + + // Request with Body > 8196 (default buffer size) + { + Req: http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + ContentLength: 8193, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: bytes.Repeat([]byte("a"), 8193), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 8193\r\n" + + "Accept-Encoding: gzip\r\n\r\n" + + strings.Repeat("a", 8193), + }, +} + +func TestDumpRequest(t *testing.T) { + numg0 := runtime.NumGoroutine() + for i, tt := range dumpTests { + setBody := func() { + if tt.Body == nil { + return + } + switch b := tt.Body.(type) { + case []byte: + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) + case func() io.ReadCloser: + tt.Req.Body = b() + default: + t.Fatalf("Test %d: unsupported Body of %T", i, tt.Body) + } + } + setBody() + if tt.Req.Header == nil { + tt.Req.Header = make(http.Header) + } + + if tt.WantDump != "" { + setBody() + dump, err := DumpRequest(&tt.Req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequest #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDump { + t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump)) + continue + } + } + + if tt.WantDumpOut != "" { + setBody() + dump, err := DumpRequestOut(&tt.Req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequestOut #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDumpOut { + t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump)) + continue + } + } + } + if dg := runtime.NumGoroutine() - numg0; dg > 4 { + buf := make([]byte, 4096) + buf = buf[:runtime.Stack(buf, true)] + t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf) + } +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} + +func mustNewRequest(method, url string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, url, body) + if err != nil { + panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err)) + } + return req +} + +var dumpResTests = []struct { + res *http.Response + body bool + want string +}{ + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 50, + Header: http.Header{ + "Foo": []string{"Bar"}, + }, + Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + }, + body: false, // to verify we see 50, not empty or 3. + want: `HTTP/1.1 200 OK +Content-Length: 50 +Foo: Bar`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 3, + Body: ioutil.NopCloser(strings.NewReader("foo")), + }, + body: true, + want: `HTTP/1.1 200 OK +Content-Length: 3 + +foo`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + Body: ioutil.NopCloser(strings.NewReader("foo")), + TransferEncoding: []string{"chunked"}, + }, + body: true, + want: `HTTP/1.1 200 OK +Transfer-Encoding: chunked + +3 +foo +0`, + }, +} + +func TestDumpResponse(t *testing.T) { + for i, tt := range dumpResTests { + gotb, err := DumpResponse(tt.res, tt.body) + if err != nil { + t.Errorf("%d. DumpResponse = %v", i, err) + continue + } + got := string(gotb) + got = strings.TrimSpace(got) + got = strings.Replace(got, "\r", "", -1) + + if got != tt.want { + t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want) + } + } +} diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go new file mode 100644 index 000000000..2e523e9e2 --- /dev/null +++ b/src/net/http/httputil/httputil.go @@ -0,0 +1,39 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import ( + "io" + "net/http/internal" +) + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + return internal.NewChunkedReader(r) +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return internal.NewChunkedWriter(w) +} + +// ErrLineTooLong is returned when reading malformed chunked data +// with lines that are too long. +var ErrLineTooLong = internal.ErrLineTooLong diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go new file mode 100644 index 000000000..987bcc96b --- /dev/null +++ b/src/net/http/httputil/persist.go @@ -0,0 +1,429 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "errors" + "io" + "net" + "net/http" + "net/textproto" + "sync" +) + +var ( + ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"} + ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} + ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} +) + +// This is an API usage error - the local side is closed. +// ErrPersistEOF (above) reports that the remote side is closed. +var errClosed = errors.New("i/o operation on closed connection") + +// A ServerConn reads requests and sends responses over an underlying +// connection, until the HTTP keepalive logic commands an end. ServerConn +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, +// i.e. requests can be read out of sync (but in the same order) while the +// respective responses are sent. +// +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. +type ServerConn struct { + lk sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline +} + +// NewServerConn returns a new ServerConn reading and writing c. If r is not +// nil, it is the buffer to use when reading c. +// +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. +func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)} +} + +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before Read has signaled the end of the keep-alive logic. The user +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { + sc.lk.Lock() + defer sc.lk.Unlock() + c = sc.c + r = sc.r + sc.c = nil + sc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Read returns the next request on the wire. An ErrPersistEOF is returned if +// it is gracefully determined that there are no more requests (e.g. after the +// first request on an HTTP/1.0 connection, or after a Connection:close on a +// HTTP/1.1 connection). +func (sc *ServerConn) Read() (req *http.Request, err error) { + + // Ensure ordered execution of Reads and Writes + id := sc.pipe.Next() + sc.pipe.StartRequest(id) + defer func() { + sc.pipe.EndRequest(id) + if req == nil { + sc.pipe.StartResponse(id) + sc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + sc.lk.Lock() + sc.pipereq[req] = id + sc.lk.Unlock() + } + }() + + sc.lk.Lock() + if sc.we != nil { // no point receiving if write-side broken or closed + defer sc.lk.Unlock() + return nil, sc.we + } + if sc.re != nil { + defer sc.lk.Unlock() + return nil, sc.re + } + if sc.r == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return nil, errClosed + } + r := sc.r + lastbody := sc.lastbody + sc.lastbody = nil + sc.lk.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + sc.lk.Lock() + defer sc.lk.Unlock() + sc.re = err + return nil, err + } + } + + req, err = http.ReadRequest(r) + sc.lk.Lock() + defer sc.lk.Unlock() + if err != nil { + if err == io.ErrUnexpectedEOF { + // A close from the opposing client is treated as a + // graceful close, even if there was some unparse-able + // data before the close. + sc.re = ErrPersistEOF + return nil, sc.re + } else { + sc.re = err + return req, err + } + } + sc.lastbody = req.Body + sc.nread++ + if req.Close { + sc.re = ErrPersistEOF + return req, sc.re + } + return req, err +} + +// Pending returns the number of unanswered requests +// that have been received on the connection. +func (sc *ServerConn) Pending() int { + sc.lk.Lock() + defer sc.lk.Unlock() + return sc.nread - sc.nwritten +} + +// Write writes resp in response to req. To close the connection gracefully, set the +// Response.Close field to true. Write should be considered operational until +// it returns an error, regardless of any errors returned on the Read side. +func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { + + // Retrieve the pipeline ID of this request/response pair + sc.lk.Lock() + id, ok := sc.pipereq[req] + delete(sc.pipereq, req) + if !ok { + sc.lk.Unlock() + return ErrPipeline + } + sc.lk.Unlock() + + // Ensure pipeline order + sc.pipe.StartResponse(id) + defer sc.pipe.EndResponse(id) + + sc.lk.Lock() + if sc.we != nil { + defer sc.lk.Unlock() + return sc.we + } + if sc.c == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return ErrClosed + } + c := sc.c + if sc.nread <= sc.nwritten { + defer sc.lk.Unlock() + return errors.New("persist server pipe count") + } + if resp.Close { + // After signaling a keep-alive close, any pipelined unread + // requests will be lost. It is up to the user to drain them + // before signaling. + sc.re = ErrPersistEOF + } + sc.lk.Unlock() + + err := resp.Write(c) + sc.lk.Lock() + defer sc.lk.Unlock() + if err != nil { + sc.we = err + return err + } + sc.nwritten++ + + return nil +} + +// A ClientConn sends request and receives headers over an underlying +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. +// +// ClientConn is low-level and old. Applications should instead use +// Client or Transport in the net/http package. +type ClientConn struct { + lk sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline + writeReq func(*http.Request, io.Writer) error +} + +// NewClientConn returns a new ClientConn reading and writing c. If r is not +// nil, it is the buffer to use when reading c. +// +// ClientConn is low-level and old. Applications should use Client or +// Transport in the net/http package. +func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ClientConn{ + c: c, + r: r, + pipereq: make(map[*http.Request]uint), + writeReq: (*http.Request).Write, + } +} + +// NewProxyClientConn works like NewClientConn but writes Requests +// using Request's WriteProxy method. +// +// New code should not use NewProxyClientConn. See Client or +// Transport in the net/http package instead. +func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + cc := NewClientConn(c, r) + cc.writeReq = (*http.Request).WriteProxy + return cc +} + +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before the user or Read have signaled the end of the keep-alive +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { + cc.lk.Lock() + defer cc.lk.Unlock() + c = cc.c + r = cc.r + cc.c = nil + cc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Write writes a request. An ErrPersistEOF error is returned if the connection +// has been closed in an HTTP keepalive sense. If req.Close equals true, the +// keepalive connection is logically closed after this request and the opposing +// server is informed. An ErrUnexpectedEOF indicates the remote closed the +// underlying TCP connection, which is usually considered as graceful close. +func (cc *ClientConn) Write(req *http.Request) (err error) { + + // Ensure ordered execution of Writes + id := cc.pipe.Next() + cc.pipe.StartRequest(id) + defer func() { + cc.pipe.EndRequest(id) + if err != nil { + cc.pipe.StartResponse(id) + cc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + cc.lk.Lock() + cc.pipereq[req] = id + cc.lk.Unlock() + } + }() + + cc.lk.Lock() + if cc.re != nil { // no point sending if read-side closed or broken + defer cc.lk.Unlock() + return cc.re + } + if cc.we != nil { + defer cc.lk.Unlock() + return cc.we + } + if cc.c == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return errClosed + } + c := cc.c + if req.Close { + // We write the EOF to the write-side error, because there + // still might be some pipelined reads + cc.we = ErrPersistEOF + } + cc.lk.Unlock() + + err = cc.writeReq(req, c) + cc.lk.Lock() + defer cc.lk.Unlock() + if err != nil { + cc.we = err + return err + } + cc.nwritten++ + + return nil +} + +// Pending returns the number of unanswered requests +// that have been sent on the connection. +func (cc *ClientConn) Pending() int { + cc.lk.Lock() + defer cc.lk.Unlock() + return cc.nwritten - cc.nread +} + +// Read reads the next response from the wire. A valid response might be +// returned together with an ErrPersistEOF, which means that the remote +// requested that this be the last request serviced. Read can be called +// concurrently with Write, but not with another Read. +func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { + // Retrieve the pipeline ID of this request/response pair + cc.lk.Lock() + id, ok := cc.pipereq[req] + delete(cc.pipereq, req) + if !ok { + cc.lk.Unlock() + return nil, ErrPipeline + } + cc.lk.Unlock() + + // Ensure pipeline order + cc.pipe.StartResponse(id) + defer cc.pipe.EndResponse(id) + + cc.lk.Lock() + if cc.re != nil { + defer cc.lk.Unlock() + return nil, cc.re + } + if cc.r == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return nil, errClosed + } + r := cc.r + lastbody := cc.lastbody + cc.lastbody = nil + cc.lk.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + cc.lk.Lock() + defer cc.lk.Unlock() + cc.re = err + return nil, err + } + } + + resp, err = http.ReadResponse(r, req) + cc.lk.Lock() + defer cc.lk.Unlock() + if err != nil { + cc.re = err + return resp, err + } + cc.lastbody = resp.Body + + cc.nread++ + + if resp.Close { + cc.re = ErrPersistEOF // don't send any more requests + return resp, cc.re + } + return resp, err +} + +// Do is convenience method that writes a request and reads a response. +func (cc *ClientConn) Do(req *http.Request) (resp *http.Response, err error) { + err = cc.Write(req) + if err != nil { + return + } + return cc.Read(req) +} diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go new file mode 100644 index 000000000..ab4637018 --- /dev/null +++ b/src/net/http/httputil/reverseproxy.go @@ -0,0 +1,225 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package httputil + +import ( + "io" + "log" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + outreq := new(http.Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(h) + } + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + defer res.Body.Close() + + for _, h := range hopHeaders { + res.Header.Del(h) + } + + copyHeader(rw.Header(), res.Header) + + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), + } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw + } + } + + io.Copy(dst, src) +} + +func (p *ReverseProxy) logf(format string, args ...interface{}) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration + + lk sync.Mutex // protects Write + Flush + done chan bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.lk.Lock() + defer m.lk.Unlock() + return m.dst.Write(p) +} + +func (m *maxLatencyWriter) flushLoop() { + t := time.NewTicker(m.latency) + defer t.Stop() + for { + select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return + case <-t.C: + m.lk.Lock() + m.dst.Flush() + m.lk.Unlock() + } + } +} + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go new file mode 100644 index 000000000..e9539b44b --- /dev/null +++ b/src/net/http/httputil/reverseproxy_test.go @@ -0,0 +1,213 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package httputil + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + hopHeaders = append(hopHeaders, fakeHopHeader) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("Upgrade", "foo") + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) +}{ + {"", "", ""}, + {"?sta=tic", "?us=er", "sta=tic&us=er"}, + {"", "?us=er", "us=er"}, + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Got-Query", r.URL.RawQuery) + w.Write([]byte("hi")) + })) + defer backend.Close() + + for i, tt := range proxyQueryTests { + backendURL, err := url.Parse(backend.URL + tt.baseSuffix) + if err != nil { + t.Fatal(err) + } + frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) + req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("%d. Get: %v", i, err) + } + if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + res.Body.Close() + frontend.Close() + } +} + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + done := make(chan bool) + onExitFlushLoop = func() { done <- true } + defer func() { onExitFlushLoop = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case <-done: + // OK + case <-time.After(5 * time.Second): + t.Error("maxLatencyWriter flushLoop() never exited") + } +} diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go new file mode 100644 index 000000000..9294deb3e --- /dev/null +++ b/src/net/http/internal/chunked.go @@ -0,0 +1,202 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// Package internal contains HTTP internals shared by net/http and +// net/http/httputil. +package internal + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line []byte + line, cr.err = readLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = parseHexUint(line) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 + } + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue + } + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } + } + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are a pointer into storage in +// the bufio, so they are only valid until the next bufio read. +func readLine(b *bufio.Reader) (p []byte, err error) { + if p, err = b.ReadSlice('\n'); err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + return trimTrailingWhitespace(p), nil +} + +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + _, err = io.WriteString(cw.Wire, "\r\n") + + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} + +func parseHexUint(v []byte) (n uint64, err error) { + for _, b := range v { + n <<= 4 + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + n |= uint64(b) + } + return +} diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go new file mode 100644 index 000000000..ebc626ea9 --- /dev/null +++ b/src/net/http/internal/chunked_test.go @@ -0,0 +1,156 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "testing" +) + +func TestChunk(t *testing.T) { + var b bytes.Buffer + + w := NewChunkedWriter(&b) + const chunk1 = "hello, " + const chunk2 = "world! 0123456789abcdef" + w.Write([]byte(chunk1)) + w.Write([]byte(chunk2)) + w.Close() + + if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e { + t.Fatalf("chunk writer wrote %q; want %q", g, e) + } + + r := NewChunkedReader(&b) + data, err := ioutil.ReadAll(r) + if err != nil { + t.Logf(`data: "%s"`, data) + t.Fatalf("ReadAll from reader: %v", err) + } + if g, e := string(data), chunk1+chunk2; g != e { + t.Errorf("chunk reader read %q; want %q", g, e) + } +} + +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := NewChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := NewChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := NewChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + +func TestChunkReaderAllocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, 0) + bufr.Reset(byter) + r := NewChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) + } +} + +func TestParseHexUint(t *testing.T) { + for i := uint64(0); i <= 1234; i++ { + line := []byte(fmt.Sprintf("%x", i)) + got, err := parseHexUint(line) + if err != nil { + t.Fatalf("on %d: %v", i, err) + } + if got != i { + t.Errorf("for input %q = %d; want %d", line, got, i) + } + } + _, err := parseHexUint([]byte("bogus")) + if err == nil { + t.Error("expected error on bogus input") + } +} diff --git a/src/net/http/jar.go b/src/net/http/jar.go new file mode 100644 index 000000000..5c3de0dad --- /dev/null +++ b/src/net/http/jar.go @@ -0,0 +1,27 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" +) + +// A CookieJar manages storage and use of cookies in HTTP requests. +// +// Implementations of CookieJar must be safe for concurrent use by multiple +// goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. +type CookieJar interface { + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. + SetCookies(u *url.URL, cookies []*Cookie) + + // Cookies returns the cookies to send in a request for the given URL. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. + Cookies(u *url.URL) []*Cookie +} diff --git a/src/net/http/lex.go b/src/net/http/lex.go new file mode 100644 index 000000000..cb33318f4 --- /dev/null +++ b/src/net/http/lex.go @@ -0,0 +1,96 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// This file deals with lexical matters of HTTP + +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +func isToken(r rune) bool { + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] +} + +func isNotToken(r rune) bool { + return !isToken(r) +} diff --git a/src/net/http/lex_test.go b/src/net/http/lex_test.go new file mode 100644 index 000000000..6d9d294f7 --- /dev/null +++ b/src/net/http/lex_test.go @@ -0,0 +1,31 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "testing" +) + +func isChar(c rune) bool { return c <= 127 } + +func isCtl(c rune) bool { return c <= 31 || c == 127 } + +func isSeparator(c rune) bool { + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': + return true + } + return false +} + +func TestIsToken(t *testing.T) { + for i := 0; i <= 130; i++ { + r := rune(i) + expected := isChar(r) && !isCtl(r) && !isSeparator(r) + if isToken(r) != expected { + t.Errorf("isToken(0x%x) = %v", r, !expected) + } + } +} diff --git a/src/net/http/main_test.go b/src/net/http/main_test.go new file mode 100644 index 000000000..b8c71fd19 --- /dev/null +++ b/src/net/http/main_test.go @@ -0,0 +1,109 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "fmt" + "net/http" + "os" + "runtime" + "sort" + "strings" + "testing" + "time" +) + +func TestMain(m *testing.M) { + v := m.Run() + if v == 0 && goroutineLeaked() { + os.Exit(1) + } + os.Exit(v) +} + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "net/http_test.interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + +// Verify the other tests didn't leave any goroutines running. +func goroutineLeaked() bool { + if testing.Short() { + // not counting goroutines for leakage in -short mode + return false + } + gs := interestingGoroutines() + + n := 0 + stackCount := make(map[string]int) + for _, g := range gs { + stackCount[g]++ + n++ + } + + if n == 0 { + return false + } + fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n") + for stack, count := range stackCount { + fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack) + } + return true +} + +func afterTest(t *testing.T) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + var bad string + badSubstring := map[string]string{ + ").readLoop(": "a Transport", + ").writeLoop(": "a Transport", + "created by net/http/httptest.(*Server).Start": "an httptest.Server", + "timeoutHandler": "a TimeoutHandler", + "net.(*netFD).connect(": "a timing out dial", + ").noteClientGone(": "a closenotifier sender", + } + var stacks string + for i := 0; i < 4; i++ { + bad = "" + stacks = strings.Join(interestingGoroutines(), "\n\n") + for substr, what := range badSubstring { + if strings.Contains(stacks, substr) { + bad = what + } + } + if bad == "" { + return + } + // Bad stuff found, but goroutines might just still be + // shutting down, so give it some time. + time.Sleep(250 * time.Millisecond) + } + t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} diff --git a/src/net/http/npn_test.go b/src/net/http/npn_test.go new file mode 100644 index 000000000..98b8930d0 --- /dev/null +++ b/src/net/http/npn_test.go @@ -0,0 +1,118 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + . "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNextProtoUpgrade(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) + if r.TLS != nil { + w.Write([]byte(r.TLS.NegotiatedProtocol)) + } + if r.RemoteAddr == "" { + t.Error("request with no RemoteAddr") + } + if r.Body == nil { + t.Errorf("request with nil Body") + } + })) + ts.TLS = &tls.Config{ + NextProtos: []string{"unhandled-proto", "tls-0.9"}, + } + ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){ + "tls-0.9": handleTLSProtocol09, + } + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Normal request, without NPN. + { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "path=/,proto="; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } + + // Request to an advertised but unhandled NPN protocol. + // Server will hang up. + { + tr.CloseIdleConnections() + tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + _, err := c.Get(ts.URL) + if err == nil { + t.Errorf("expected error on unhandled-proto request") + } + } + + // Request using the "tls-0.9" protocol, which we register here. + // It is HTTP/0.9 over TLS. + { + tlsConfig := newTLSTransport(t, ts).TLSClientConfig + tlsConfig.NextProtos = []string{"tls-0.9"} + conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET /foo\n")) + body, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if want := "path=/foo,proto=tls-0.9"; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } +} + +// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the +// TestNextProtoUpgrade test. +func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) { + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimSpace(line) + path := strings.TrimPrefix(line, "GET ") + if path == line { + return + } + req, _ := NewRequest("GET", path, nil) + req.Proto = "HTTP/0.9" + req.ProtoMajor = 0 + req.ProtoMinor = 9 + rw := &http09Writer{conn, make(Header)} + h.ServeHTTP(rw, req) +} + +type http09Writer struct { + io.Writer + h Header +} + +func (w http09Writer) Header() Header { return w.h } +func (w http09Writer) WriteHeader(int) {} // no headers diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go new file mode 100644 index 000000000..a23f1bc4b --- /dev/null +++ b/src/net/http/pprof/pprof.go @@ -0,0 +1,209 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pprof serves via its HTTP server runtime profiling data +// in the format expected by the pprof visualization tool. +// For more information about pprof, see +// http://code.google.com/p/google-perftools/. +// +// The package is typically only imported for the side effect of +// registering its HTTP handlers. +// The handled paths all begin with /debug/pprof/. +// +// To use pprof, link this package into your program: +// import _ "net/http/pprof" +// +// If your application is not already running an http server, you +// need to start one. Add "net/http" and "log" to your imports and +// the following code to your main function: +// +// go func() { +// log.Println(http.ListenAndServe("localhost:6060", nil)) +// }() +// +// Then use the pprof tool to look at the heap profile: +// +// go tool pprof http://localhost:6060/debug/pprof/heap +// +// Or to look at a 30-second CPU profile: +// +// go tool pprof http://localhost:6060/debug/pprof/profile +// +// Or to look at the goroutine blocking profile: +// +// go tool pprof http://localhost:6060/debug/pprof/block +// +// To view all available profiles, open http://localhost:6060/debug/pprof/ +// in your browser. +// +// For a study of the facility in action, visit +// +// http://blog.golang.org/2011/06/profiling-go-programs.html +// +package pprof + +import ( + "bufio" + "bytes" + "fmt" + "html/template" + "io" + "log" + "net/http" + "os" + "runtime" + "runtime/pprof" + "strconv" + "strings" + "time" +) + +func init() { + http.Handle("/debug/pprof/", http.HandlerFunc(Index)) + http.Handle("/debug/pprof/cmdline", http.HandlerFunc(Cmdline)) + http.Handle("/debug/pprof/profile", http.HandlerFunc(Profile)) + http.Handle("/debug/pprof/symbol", http.HandlerFunc(Symbol)) +} + +// Cmdline responds with the running program's +// command line, with arguments separated by NUL bytes. +// The package initialization registers it as /debug/pprof/cmdline. +func Cmdline(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprintf(w, strings.Join(os.Args, "\x00")) +} + +// Profile responds with the pprof-formatted cpu profile. +// The package initialization registers it as /debug/pprof/profile. +func Profile(w http.ResponseWriter, r *http.Request) { + sec, _ := strconv.ParseInt(r.FormValue("seconds"), 10, 64) + if sec == 0 { + sec = 30 + } + + // Set Content Type assuming StartCPUProfile will work, + // because if it does it starts writing. + w.Header().Set("Content-Type", "application/octet-stream") + if err := pprof.StartCPUProfile(w); err != nil { + // StartCPUProfile failed, so no writes yet. + // Can change header back to text content + // and send error code. + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + return + } + time.Sleep(time.Duration(sec) * time.Second) + pprof.StopCPUProfile() +} + +// Symbol looks up the program counters listed in the request, +// responding with a table mapping program counters to function names. +// The package initialization registers it as /debug/pprof/symbol. +func Symbol(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + + // We have to read the whole POST body before + // writing any output. Buffer the output here. + var buf bytes.Buffer + + // We don't know how many symbols we have, but we + // do have symbol information. Pprof only cares whether + // this number is 0 (no symbols available) or > 0. + fmt.Fprintf(&buf, "num_symbols: 1\n") + + var b *bufio.Reader + if r.Method == "POST" { + b = bufio.NewReader(r.Body) + } else { + b = bufio.NewReader(strings.NewReader(r.URL.RawQuery)) + } + + for { + word, err := b.ReadSlice('+') + if err == nil { + word = word[0 : len(word)-1] // trim + + } + pc, _ := strconv.ParseUint(string(word), 0, 64) + if pc != 0 { + f := runtime.FuncForPC(uintptr(pc)) + if f != nil { + fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name()) + } + } + + // Wait until here to check for err; the last + // symbol will have an err because it doesn't end in +. + if err != nil { + if err != io.EOF { + fmt.Fprintf(&buf, "reading request: %v\n", err) + } + break + } + } + + w.Write(buf.Bytes()) +} + +// Handler returns an HTTP handler that serves the named profile. +func Handler(name string) http.Handler { + return handler(name) +} + +type handler string + +func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + debug, _ := strconv.Atoi(r.FormValue("debug")) + p := pprof.Lookup(string(name)) + if p == nil { + w.WriteHeader(404) + fmt.Fprintf(w, "Unknown profile: %s\n", name) + return + } + gc, _ := strconv.Atoi(r.FormValue("gc")) + if name == "heap" && gc > 0 { + runtime.GC() + } + p.WriteTo(w, debug) + return +} + +// Index responds with the pprof-formatted profile named by the request. +// For example, "/debug/pprof/heap" serves the "heap" profile. +// Index responds to a request for "/debug/pprof/" with an HTML page +// listing the available profiles. +func Index(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/debug/pprof/") { + name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/") + if name != "" { + handler(name).ServeHTTP(w, r) + return + } + } + + profiles := pprof.Profiles() + if err := indexTmpl.Execute(w, profiles); err != nil { + log.Print(err) + } +} + +var indexTmpl = template.Must(template.New("index").Parse(`<html> +<head> +<title>/debug/pprof/</title> +</head> +/debug/pprof/<br> +<br> +<body> +profiles:<br> +<table> +{{range .}} +<tr><td align=right>{{.Count}}<td><a href="/debug/pprof/{{.Name}}?debug=1">{{.Name}}</a> +{{end}} +</table> +<br> +<a href="/debug/pprof/goroutine?debug=2">full goroutine stack dump</a><br> +</body> +</html> +`)) diff --git a/src/net/http/proxy_test.go b/src/net/http/proxy_test.go new file mode 100644 index 000000000..b6aed3792 --- /dev/null +++ b/src/net/http/proxy_test.go @@ -0,0 +1,81 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" + "os" + "testing" +) + +// TODO(mattn): +// test ProxyAuth + +var UseProxyTests = []struct { + host string + match bool +}{ + // Never proxy localhost: + {"localhost:80", false}, + {"127.0.0.1", false}, + {"127.0.0.2", false}, + {"[::1]", false}, + {"[::2]", true}, // not a loopback address + + {"barbaz.net", false}, // match as .barbaz.net + {"foobar.com", false}, // have a port but match + {"foofoobar.com", true}, // not match as a part of foobar.com + {"baz.com", true}, // not match as a part of barbaz.com + {"localhost.net", true}, // not match as suffix of address + {"local.localhost", true}, // not match as prefix as address + {"barbarbaz.net", true}, // not match because NO_PROXY have a '.' + {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com" +} + +func TestUseProxy(t *testing.T) { + ResetProxyEnv() + os.Setenv("NO_PROXY", "foobar.com, .barbaz.net") + for _, test := range UseProxyTests { + if useProxy(test.host+":80") != test.match { + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) + } + } +} + +var cacheKeysTests = []struct { + proxy string + scheme string + addr string + key string +}{ + {"", "http", "foo.com", "|http|foo.com"}, + {"", "https", "foo.com", "|https|foo.com"}, + {"http://foo.com", "http", "foo.com", "http://foo.com|http|"}, + {"http://foo.com", "https", "foo.com", "http://foo.com|https|foo.com"}, +} + +func TestCacheKeys(t *testing.T) { + for _, tt := range cacheKeysTests { + var proxy *url.URL + if tt.proxy != "" { + u, err := url.Parse(tt.proxy) + if err != nil { + t.Fatal(err) + } + proxy = u + } + cm := connectMethod{proxy, tt.scheme, tt.addr} + if got := cm.key().String(); got != tt.key { + t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key) + } + } +} + +func ResetProxyEnv() { + for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} { + os.Setenv(v, "") + } + ResetCachedEnvironment() +} diff --git a/src/net/http/race.go b/src/net/http/race.go new file mode 100644 index 000000000..766503967 --- /dev/null +++ b/src/net/http/race.go @@ -0,0 +1,11 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package http + +func init() { + raceEnabled = true +} diff --git a/src/net/http/range_test.go b/src/net/http/range_test.go new file mode 100644 index 000000000..ef911af7b --- /dev/null +++ b/src/net/http/range_test.go @@ -0,0 +1,79 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "testing" +) + +var ParseRangeTests = []struct { + s string + length int64 + r []httpRange +}{ + {"", 0, nil}, + {"", 1000, nil}, + {"foo", 0, nil}, + {"bytes=", 0, nil}, + {"bytes=7", 10, nil}, + {"bytes= 7 ", 10, nil}, + {"bytes=1-", 0, nil}, + {"bytes=5-4", 10, nil}, + {"bytes=0-2,5-4", 10, nil}, + {"bytes=2-5,4-3", 10, nil}, + {"bytes=--5,4--3", 10, nil}, + {"bytes=A-", 10, nil}, + {"bytes=A- ", 10, nil}, + {"bytes=A-Z", 10, nil}, + {"bytes= -Z", 10, nil}, + {"bytes=5-Z", 10, nil}, + {"bytes=Ran-dom, garbage", 10, nil}, + {"bytes=0x01-0x02", 10, nil}, + {"bytes= ", 10, nil}, + {"bytes= , , , ", 10, nil}, + + {"bytes=0-9", 10, []httpRange{{0, 10}}}, + {"bytes=0-", 10, []httpRange{{0, 10}}}, + {"bytes=5-", 10, []httpRange{{5, 5}}}, + {"bytes=0-20", 10, []httpRange{{0, 10}}}, + {"bytes=15-,0-5", 10, nil}, + {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}}, + {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}}, + {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}}, + {"bytes=-5", 10, []httpRange{{5, 5}}}, + {"bytes=-15", 10, []httpRange{{0, 10}}}, + {"bytes=0-499", 10000, []httpRange{{0, 500}}}, + {"bytes=500-999", 10000, []httpRange{{500, 500}}}, + {"bytes=-500", 10000, []httpRange{{9500, 500}}}, + {"bytes=9500-", 10000, []httpRange{{9500, 500}}}, + {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, + {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, + {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, + + // Match Apache laxity: + {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}}, +} + +func TestParseRange(t *testing.T) { + for _, test := range ParseRangeTests { + r := test.r + ranges, err := parseRange(test.s, test.length) + if err != nil && r != nil { + t.Errorf("parseRange(%q) returned error %q", test.s, err) + } + if len(ranges) != len(r) { + t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r)) + continue + } + for i := range r { + if ranges[i].start != r[i].start { + t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start) + } + if ranges[i].length != r[i].length { + t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length) + } + } + } +} diff --git a/src/net/http/readrequest_test.go b/src/net/http/readrequest_test.go new file mode 100644 index 000000000..e930d99af --- /dev/null +++ b/src/net/http/readrequest_test.go @@ -0,0 +1,358 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/url" + "reflect" + "strings" + "testing" +) + +type reqTest struct { + Raw string + Req *Request + Body string + Trailer Header + Error string +} + +var noError = "" +var noBody = "" +var noTrailer Header = nil + +var reqTests = []reqTest{ + // Baseline test; All Request fields included for template use + { + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Keep-Alive: 300\r\n" + + "Content-Length: 7\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n" + + "abcdef\n???", + + &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.techcrunch.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "Content-Length": {"7"}, + "User-Agent": {"Fake"}, + }, + Close: false, + ContentLength: 7, + Host: "www.techcrunch.com", + RequestURI: "http://www.techcrunch.com/", + }, + + "abcdef\n", + + noTrailer, + noError, + }, + + // GET request with no body (the normal case) + { + "GET / HTTP/1.1\r\n" + + "Host: foo.com\r\n\r\n", + + &Request{ + Method: "GET", + URL: &url.URL{ + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "foo.com", + RequestURI: "/", + }, + + noBody, + noTrailer, + noError, + }, + + // Tests that we don't parse a path that looks like a + // scheme-relative URI as a scheme-relative URI. + { + "GET //user@host/is/actually/a/path/ HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + + &Request{ + Method: "GET", + URL: &url.URL{ + Path: "//user@host/is/actually/a/path/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "test", + RequestURI: "//user@host/is/actually/a/path/", + }, + + noBody, + noTrailer, + noError, + }, + + // Tests a bogus abs_path on the Request-Line (RFC 2616 section 5.1.2) + { + "GET ../../../../etc/passwd HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + nil, + noBody, + noTrailer, + "parse ../../../../etc/passwd: invalid URI for request", + }, + + // Tests missing URL: + { + "GET HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + nil, + noBody, + noTrailer, + "parse : empty url", + }, + + // Tests chunked body with trailer: + { + "POST / HTTP/1.1\r\n" + + "Host: foo.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "3\r\nfoo\r\n" + + "3\r\nbar\r\n" + + "0\r\n" + + "Trailer-Key: Trailer-Value\r\n" + + "\r\n", + &Request{ + Method: "POST", + URL: &url.URL{ + Path: "/", + }, + TransferEncoding: []string{"chunked"}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + ContentLength: -1, + Host: "foo.com", + RequestURI: "/", + }, + + "foobar", + Header{ + "Trailer-Key": {"Trailer-Value"}, + }, + noError, + }, + + // CONNECT request with domain name: + { + "CONNECT www.google.com:443 HTTP/1.1\r\n\r\n", + + &Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: "www.google.com:443", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "www.google.com:443", + RequestURI: "www.google.com:443", + }, + + noBody, + noTrailer, + noError, + }, + + // CONNECT request with IP address: + { + "CONNECT 127.0.0.1:6060 HTTP/1.1\r\n\r\n", + + &Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: "127.0.0.1:6060", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "127.0.0.1:6060", + RequestURI: "127.0.0.1:6060", + }, + + noBody, + noTrailer, + noError, + }, + + // CONNECT request for RPC: + { + "CONNECT /_goRPC_ HTTP/1.1\r\n\r\n", + + &Request{ + Method: "CONNECT", + URL: &url.URL{ + Path: "/_goRPC_", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "", + RequestURI: "/_goRPC_", + }, + + noBody, + noTrailer, + noError, + }, + + // SSDP Notify request. golang.org/issue/3692 + { + "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n", + &Request{ + Method: "NOTIFY", + URL: &url.URL{ + Path: "*", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Server": []string{"foo"}, + }, + Close: false, + ContentLength: 0, + RequestURI: "*", + }, + + noBody, + noTrailer, + noError, + }, + + // OPTIONS request. Similar to golang.org/issue/3692 + { + "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n", + &Request{ + Method: "OPTIONS", + URL: &url.URL{ + Path: "*", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Server": []string{"foo"}, + }, + Close: false, + ContentLength: 0, + RequestURI: "*", + }, + + noBody, + noTrailer, + noError, + }, + + // Connection: close. golang.org/issue/8261 + { + "GET / HTTP/1.1\r\nHost: issue8261.com\r\nConnection: close\r\n\r\n", + &Request{ + Method: "GET", + URL: &url.URL{ + Path: "/", + }, + Header: Header{ + // This wasn't removed from Go 1.0 to + // Go 1.3, so locking it in that we + // keep this: + "Connection": []string{"close"}, + }, + Host: "issue8261.com", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Close: true, + RequestURI: "/", + }, + + noBody, + noTrailer, + noError, + }, +} + +func TestReadRequest(t *testing.T) { + for i := range reqTests { + tt := &reqTests[i] + req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.Raw))) + if err != nil { + if err.Error() != tt.Error { + t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error) + } + continue + } + rbody := req.Body + req.Body = nil + testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw) + diff(t, testName, req, tt.Req) + var bout bytes.Buffer + if rbody != nil { + _, err := io.Copy(&bout, rbody) + if err != nil { + t.Fatalf("%s: copying body: %v", testName, err) + } + rbody.Close() + } + body := bout.String() + if body != tt.Body { + t.Errorf("%s: Body = %q want %q", testName, body, tt.Body) + } + if !reflect.DeepEqual(tt.Trailer, req.Trailer) { + t.Errorf("%s: Trailers differ.\n got: %v\nwant: %v", testName, req.Trailer, tt.Trailer) + } + } +} diff --git a/src/net/http/request.go b/src/net/http/request.go new file mode 100644 index 000000000..487eebcb8 --- /dev/null +++ b/src/net/http/request.go @@ -0,0 +1,921 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Request reading and parsing. + +package http + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "mime" + "mime/multipart" + "net/textproto" + "net/url" + "strconv" + "strings" + "sync" +) + +const ( + maxValueLength = 4096 + maxHeaderLines = 1024 + chunkSize = 4 << 10 // 4 KB chunks + defaultMaxMemory = 32 << 20 // 32 MB +) + +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = errors.New("http: no such file") + +// HTTP request parsing errors. +type ProtocolError struct { + ErrorString string +} + +func (err *ProtocolError) Error() string { return err.ErrorString } + +var ( + ErrHeaderTooLong = &ProtocolError{"header too long"} + ErrShortBody = &ProtocolError{"entity body too short"} + ErrNotSupported = &ProtocolError{"feature not supported"} + ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} + ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} +) + +type badStringError struct { + what string + str string +} + +func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// A Request represents an HTTP request received by a server +// or to be sent by a client. +// +// The field semantics differ slightly between client and server +// usage. In addition to the notes on the fields below, see the +// documentation for Request.Write and RoundTripper. +type Request struct { + // Method specifies the HTTP method (GET, POST, PUT, etc.). + // For client requests an empty string means GET. + Method string + + // URL specifies either the URI being requested (for server + // requests) or the URL to access (for client requests). + // + // For server requests the URL is parsed from the URI + // supplied on the Request-Line as stored in RequestURI. For + // most requests, fields other than Path and RawQuery will be + // empty. (See RFC 2616, Section 5.1.2) + // + // For client requests, the URL's Host specifies the server to + // connect to, while the Request's Host field optionally + // specifies the Host header value to send in the HTTP + // request. + URL *url.URL + + // The protocol version for incoming requests. + // Client requests always use HTTP/1.1. + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + + // A header maps request lines to their values. + // If the header says + // + // accept-encoding: gzip, deflate + // Accept-Language: en-us + // Connection: keep-alive + // + // then + // + // Header = map[string][]string{ + // "Accept-Encoding": {"gzip, deflate"}, + // "Accept-Language": {"en-us"}, + // "Connection": {"keep-alive"}, + // } + // + // HTTP defines that header names are case-insensitive. + // The request parser implements this by canonicalizing the + // name, making the first character and any characters + // following a hyphen uppercase and the rest lowercase. + // + // For client requests certain headers are automatically + // added and may override values in Header. + // + // See the documentation for the Request.Write method. + Header Header + + // Body is the request's body. + // + // For client requests a nil body means the request has no + // body, such as a GET request. The HTTP Client's Transport + // is responsible for calling the Close method. + // + // For server requests the Request Body is always non-nil + // but will return EOF immediately when no body is present. + // The Server will close the request body. The ServeHTTP + // Handler does not need to. + Body io.ReadCloser + + // ContentLength records the length of the associated content. + // The value -1 indicates that the length is unknown. + // Values >= 0 indicate that the given number of bytes may + // be read from Body. + // For client requests, a value of 0 means unknown if Body is not nil. + ContentLength int64 + + // TransferEncoding lists the transfer encodings from outermost to + // innermost. An empty list denotes the "identity" encoding. + // TransferEncoding can usually be ignored; chunked encoding is + // automatically added and removed as necessary when sending and + // receiving requests. + TransferEncoding []string + + // Close indicates whether to close the connection after + // replying to this request (for servers) or after sending + // the request (for clients). + Close bool + + // For server requests Host specifies the host on which the + // URL is sought. Per RFC 2616, this is either the value of + // the "Host" header or the host name given in the URL itself. + // It may be of the form "host:port". + // + // For client requests Host optionally overrides the Host + // header to send. If empty, the Request.Write method uses + // the value of URL.Host. + Host string + + // Form contains the parsed form data, including both the URL + // field's query parameters and the POST or PUT form data. + // This field is only available after ParseForm is called. + // The HTTP client ignores Form and uses Body instead. + Form url.Values + + // PostForm contains the parsed form data from POST or PUT + // body parameters. + // This field is only available after ParseForm is called. + // The HTTP client ignores PostForm and uses Body instead. + PostForm url.Values + + // MultipartForm is the parsed multipart form, including file uploads. + // This field is only available after ParseMultipartForm is called. + // The HTTP client ignores MultipartForm and uses Body instead. + MultipartForm *multipart.Form + + // Trailer specifies additional headers that are sent after the request + // body. + // + // For server requests the Trailer map initially contains only the + // trailer keys, with nil values. (The client declares which trailers it + // will later send.) While the handler is reading from Body, it must + // not reference Trailer. After reading from Body returns EOF, Trailer + // can be read again and will contain non-nil values, if they were sent + // by the client. + // + // For client requests Trailer must be initialized to a map containing + // the trailer keys to later send. The values may be nil or their final + // values. The ContentLength must be 0 or -1, to send a chunked request. + // After the HTTP request is sent the map values can be updated while + // the request body is read. Once the body returns EOF, the caller must + // not mutate Trailer. + // + // Few HTTP clients, servers, or proxies support HTTP trailers. + Trailer Header + + // RemoteAddr allows HTTP servers and other software to record + // the network address that sent the request, usually for + // logging. This field is not filled in by ReadRequest and + // has no defined format. The HTTP server in this package + // sets RemoteAddr to an "IP:port" address before invoking a + // handler. + // This field is ignored by the HTTP client. + RemoteAddr string + + // RequestURI is the unmodified Request-URI of the + // Request-Line (RFC 2616, Section 5.1) as sent by the client + // to a server. Usually the URL field should be used instead. + // It is an error to set this field in an HTTP client request. + RequestURI string + + // TLS allows HTTP servers and other software to record + // information about the TLS connection on which the request + // was received. This field is not filled in by ReadRequest. + // The HTTP server in this package sets the field for + // TLS-enabled connections before invoking a handler; + // otherwise it leaves the field nil. + // This field is ignored by the HTTP client. + TLS *tls.ConnectionState +} + +// ProtoAtLeast reports whether the HTTP protocol used +// in the request is at least major.minor. +func (r *Request) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + +// UserAgent returns the client's User-Agent, if sent in the request. +func (r *Request) UserAgent() string { + return r.Header.Get("User-Agent") +} + +// Cookies parses and returns the HTTP cookies sent with the request. +func (r *Request) Cookies() []*Cookie { + return readCookies(r.Header, "") +} + +var ErrNoCookie = errors.New("http: named cookie not present") + +// Cookie returns the named cookie provided in the request or +// ErrNoCookie if not found. +func (r *Request) Cookie(name string) (*Cookie, error) { + for _, c := range readCookies(r.Header, name) { + return c, nil + } + return nil, ErrNoCookie +} + +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. That +// means all cookies, if any, are written into the same line, +// separated by semicolon. +func (r *Request) AddCookie(c *Cookie) { + s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + if c := r.Header.Get("Cookie"); c != "" { + r.Header.Set("Cookie", c+"; "+s) + } else { + r.Header.Set("Cookie", s) + } +} + +// Referer returns the referring URL, if sent in the request. +// +// Referer is misspelled as in the request itself, a mistake from the +// earliest days of HTTP. This value can also be fetched from the +// Header map as Header["Referer"]; the benefit of making it available +// as a method is that the compiler can diagnose programs that use the +// alternate (correct English) spelling req.Referrer() but cannot +// diagnose programs that use Header["Referrer"]. +func (r *Request) Referer() string { + return r.Header.Get("Referer") +} + +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartFrom. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + +// MultipartReader returns a MIME multipart reader if this is a +// multipart/form-data POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. +func (r *Request) MultipartReader() (*multipart.Reader, error) { + if r.MultipartForm == multipartByReader { + return nil, errors.New("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, errors.New("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader() +} + +func (r *Request) multipartReader() (*multipart.Reader, error) { + v := r.Header.Get("Content-Type") + if v == "" { + return nil, ErrNotMultipart + } + d, params, err := mime.ParseMediaType(v) + if err != nil || d != "multipart/form-data" { + return nil, ErrNotMultipart + } + boundary, ok := params["boundary"] + if !ok { + return nil, ErrMissingBoundary + } + return multipart.NewReader(r.Body, boundary), nil +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed from "Go http package" to "Go 1.1 package http" at the +// time of the Go 1.1 release because the former User-Agent had ended up +// on a blacklist for some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go 1.1 package http" + +// Write writes an HTTP/1.1 request -- header and body -- in wire format. +// This method consults the following fields of the request: +// Host +// URL +// Method (defaults to "GET") +// Header +// ContentLength +// TransferEncoding +// Body +// +// If Body is present, Content-Length is <= 0 and TransferEncoding +// hasn't been set to "identity", Write adds "Transfer-Encoding: +// chunked" to the header. Body is closed after it is sent. +func (r *Request) Write(w io.Writer) error { + return r.write(w, false, nil) +} + +// WriteProxy is like Write but writes the request in the form +// expected by an HTTP proxy. In particular, WriteProxy writes the +// initial Request-URI line of the request with an absolute URI, per +// section 5.1.2 of RFC 2616, including the scheme and host. +// In either case, WriteProxy also writes a Host header, using +// either r.Host or r.URL.Host. +func (r *Request) WriteProxy(w io.Writer) error { + return r.write(w, true, nil) +} + +// extraHeaders may be nil +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error { + host := req.Host + if host == "" { + if req.URL == nil { + return errors.New("http: Request.Write on Request with no Host or URL set") + } + host = req.URL.Host + } + + ruri := req.URL.RequestURI() + if usingProxy && req.URL.Scheme != "" && req.URL.Opaque == "" { + ruri = req.URL.Scheme + "://" + host + ruri + } else if req.Method == "CONNECT" && req.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + } + // TODO(bradfitz): escape at least newlines in ruri? + + // Wrap the writer in a bufio Writer if it's not already buffered. + // Don't always call NewWriter, as that forces a bytes.Buffer + // and other small bufio Writers to have a minimum 4k buffer + // size. + var bw *bufio.Writer + if _, ok := w.(io.ByteWriter); !ok { + bw = bufio.NewWriter(w) + w = bw + } + + _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + if err != nil { + return err + } + + // Header lines + _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + if err != nil { + return err + } + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := defaultUserAgent + if req.Header != nil { + if ua := req.Header["User-Agent"]; len(ua) > 0 { + userAgent = ua[0] + } + } + if userAgent != "" { + _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + if err != nil { + return err + } + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(req) + if err != nil { + return err + } + err = tw.WriteHeader(w) + if err != nil { + return err + } + + err = req.Header.WriteSubset(w, reqWriteExcludeHeader) + if err != nil { + return err + } + + if extraHeaders != nil { + err = extraHeaders.Write(w) + if err != nil { + return err + } + } + + _, err = io.WriteString(w, "\r\n") + if err != nil { + return err + } + + // Write body and trailer + err = tw.WriteBody(w) + if err != nil { + return err + } + + if bw != nil { + return bw.Flush() + } + return nil +} + +// ParseHTTPVersion parses a HTTP version string. +// "HTTP/1.0" returns (1, 0, true). +func ParseHTTPVersion(vers string) (major, minor int, ok bool) { + const Big = 1000000 // arbitrary upper bound + switch vers { + case "HTTP/1.1": + return 1, 1, true + case "HTTP/1.0": + return 1, 0, true + } + if !strings.HasPrefix(vers, "HTTP/") { + return 0, 0, false + } + dot := strings.Index(vers, ".") + if dot < 0 { + return 0, 0, false + } + major, err := strconv.Atoi(vers[5:dot]) + if err != nil || major < 0 || major > Big { + return 0, 0, false + } + minor, err = strconv.Atoi(vers[dot+1:]) + if err != nil || minor < 0 || minor > Big { + return 0, 0, false + } + return major, minor, true +} + +// NewRequest returns a new Request given a method, URL, and optional body. +// +// If the provided body is also an io.Closer, the returned +// Request.Body is set to body and will be closed by the Client +// methods Do, Post, and PostForm, and Transport.RoundTrip. +func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + u, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = ioutil.NopCloser(body) + } + req := &Request{ + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + case *strings.Reader: + req.ContentLength = int64(v.Len()) + } + } + + return req, nil +} + +// BasicAuth returns the username and password provided in the request's +// Authorization header, if the request uses HTTP Basic Authentication. +// See RFC 2617, Section 2. +func (r *Request) BasicAuth() (username, password string, ok bool) { + auth := r.Header.Get("Authorization") + if auth == "" { + return + } + return parseBasicAuth(auth) +} + +// parseBasicAuth parses an HTTP Basic Authentication string. +// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). +func parseBasicAuth(auth string) (username, password string, ok bool) { + if !strings.HasPrefix(auth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + return cs[:s], cs[s+1:], true +} + +// SetBasicAuth sets the request's Authorization header to use HTTP +// Basic Authentication with the provided username and password. +// +// With HTTP Basic Authentication the provided username and password +// are not encrypted. +func (r *Request) SetBasicAuth(username, password string) { + r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) +} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +var textprotoReaderPool sync.Pool + +func newTextprotoReader(br *bufio.Reader) *textproto.Reader { + if v := textprotoReaderPool.Get(); v != nil { + tr := v.(*textproto.Reader) + tr.R = br + return tr + } + return textproto.NewReader(br) +} + +func putTextprotoReader(r *textproto.Reader) { + r.R = nil + textprotoReaderPool.Put(r) +} + +// ReadRequest reads and parses a request from b. +func ReadRequest(b *bufio.Reader) (req *Request, err error) { + + tp := newTextprotoReader(b) + req = new(Request) + + // First line: GET /index.html HTTP/1.0 + var s string + if s, err = tp.ReadLine(); err != nil { + return nil, err + } + defer func() { + putTextprotoReader(tp) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + var ok bool + req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) + if !ok { + return nil, &badStringError{"malformed HTTP request", s} + } + rawurl := req.RequestURI + if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { + return nil, &badStringError{"malformed HTTP version", req.Proto} + } + + // CONNECT requests are used two different ways, and neither uses a full URL: + // The standard use is to tunnel HTTPS through an HTTP proxy. + // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is + // just the authority section of a URL. This information should go in req.URL.Host. + // + // The net/rpc package also uses CONNECT, but there the parameter is a path + // that starts with a slash. It can be parsed with the regular URL parser, + // and the path will end up in req.URL.Path, where it needs to be in order for + // RPC to work. + justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") + if justAuthority { + rawurl = "http://" + rawurl + } + + if req.URL, err = url.ParseRequestURI(rawurl); err != nil { + return nil, err + } + + if justAuthority { + // Strip the bogus "http://" back off. + req.URL.Scheme = "" + } + + // Subsequent lines: Key: value. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err + } + req.Header = Header(mimeHeader) + + // RFC2616: Must treat + // GET /index.html HTTP/1.1 + // Host: www.google.com + // and + // GET http://www.google.com/index.html HTTP/1.1 + // Host: doesntmatter + // the same. In the second case, any Host line is ignored. + req.Host = req.URL.Host + if req.Host == "" { + req.Host = req.Header.get("Host") + } + delete(req.Header, "Host") + + fixPragmaCacheControl(req.Header) + + err = readTransfer(req, b) + if err != nil { + return nil, err + } + + req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false) + return req, nil +} + +// MaxBytesReader is similar to io.LimitReader but is intended for +// limiting the size of incoming request bodies. In contrast to +// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a +// non-EOF error for a Read beyond the limit, and Closes the +// underlying reader when its Close method is called. +// +// MaxBytesReader prevents clients from accidentally or maliciously +// sending a large request and wasting server resources. +func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + return &maxBytesReader{w: w, r: r, n: n} +} + +type maxBytesReader struct { + w ResponseWriter + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + stopped bool +} + +func (l *maxBytesReader) Read(p []byte) (n int, err error) { + if l.n <= 0 { + if !l.stopped { + l.stopped = true + if res, ok := l.w.(*response); ok { + res.requestTooLarge() + } + } + return 0, errors.New("http: request body too large") + } + if int64(len(p)) > l.n { + p = p[:l.n] + } + n, err = l.r.Read(p) + l.n -= int64(n) + return +} + +func (l *maxBytesReader) Close() error { + return l.r.Close() +} + +func copyValues(dst, src url.Values) { + for k, vs := range src { + for _, value := range vs { + dst.Add(k, value) + } + } +} + +func parsePostForm(r *Request) (vs url.Values, err error) { + if r.Body == nil { + err = errors.New("missing form body") + return + } + ct := r.Header.Get("Content-Type") + // RFC 2616, section 7.2.1 - empty type + // SHOULD be treated as application/octet-stream + if ct == "" { + ct = "application/octet-stream" + } + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := ioutil.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + err = errors.New("http: POST too large") + return + } + vs, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestParseMultipartFormOrder and others. + } + return +} + +// ParseForm parses the raw query from the URL and updates r.Form. +// +// For POST or PUT requests, it also parses the request body as a form and +// put the results into both r.PostForm and r.Form. +// POST and PUT body parameters take precedence over URL query string values +// in r.Form. +// +// If the request Body's size has not already been limited by MaxBytesReader, +// the size is capped at 10MB. +// +// ParseMultipartForm calls ParseForm automatically. +// It is idempotent. +func (r *Request) ParseForm() error { + var err error + if r.PostForm == nil { + if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { + r.PostForm, err = parsePostForm(r) + } + if r.PostForm == nil { + r.PostForm = make(url.Values) + } + } + if r.Form == nil { + if len(r.PostForm) > 0 { + r.Form = make(url.Values) + copyValues(r.Form, r.PostForm) + } + var newValues url.Values + if r.URL != nil { + var e error + newValues, e = url.ParseQuery(r.URL.RawQuery) + if err == nil { + err = e + } + } + if newValues == nil { + newValues = make(url.Values) + } + if r.Form == nil { + r.Form = newValues + } else { + copyValues(r.Form, newValues) + } + } + return err +} + +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) error { + if r.MultipartForm == multipartByReader { + return errors.New("http: multipart handled by MultipartReader") + } + if r.Form == nil { + err := r.ParseForm() + if err != nil { + return err + } + } + if r.MultipartForm != nil { + return nil + } + + mr, err := r.multipartReader() + if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + } + r.MultipartForm = f + + return nil +} + +// FormValue returns the first value for the named component of the query. +// POST and PUT body parameters take precedence over URL query string values. +// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores +// any errors returned by these functions. +// To access multiple values of the same key, call ParseForm and +// then inspect Request.Form directly. +func (r *Request) FormValue(key string) string { + if r.Form == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.Form[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + +// PostFormValue returns the first value for the named component of the POST +// or PUT request body. URL query parameters are ignored. +// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores +// any errors returned by these functions. +func (r *Request) PostFormValue(key string) string { + if r.PostForm == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.PostForm[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { + if r.MultipartForm == multipartByReader { + return nil, nil, errors.New("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if r.MultipartForm != nil && r.MultipartForm.File != nil { + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + } + return nil, nil, ErrMissingFile +} + +func (r *Request) expectsContinue() bool { + return hasToken(r.Header.get("Expect"), "100-continue") +} + +func (r *Request) wantsHttp10KeepAlive() bool { + if r.ProtoMajor != 1 || r.ProtoMinor != 0 { + return false + } + return hasToken(r.Header.get("Connection"), "keep-alive") +} + +func (r *Request) wantsClose() bool { + return hasToken(r.Header.get("Connection"), "close") +} + +func (r *Request) closeBody() { + if r.Body != nil { + r.Body.Close() + } +} diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go new file mode 100644 index 000000000..759ea4e8b --- /dev/null +++ b/src/net/http/request_test.go @@ -0,0 +1,680 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "bytes" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + . "net/http" + "net/http/httptest" + "net/url" + "os" + "reflect" + "regexp" + "strings" + "testing" +) + +func TestQuery(t *testing.T) { + req := &Request{Method: "GET"} + req.URL, _ = url.Parse("http://www.google.com/search?q=foo&q=bar") + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } +} + +func TestPostQuery(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } + if z := req.FormValue("z"); z != "post" { + t.Errorf(`req.FormValue("z") = %q, want "post"`, z) + } + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) + } +} + +func TestPatchQuery(t *testing.T) { + req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } + if z := req.FormValue("z"); z != "post" { + t.Errorf(`req.FormValue("z") = %q, want "post"`, z) + } + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) + } +} + +type stringMap map[string][]string +type parseContentTypeTest struct { + shouldError bool + contentType stringMap +} + +var parseContentTypeTests = []parseContentTypeTest{ + {false, stringMap{"Content-Type": {"text/plain"}}}, + // Empty content type is legal - shoult be treated as + // application/octet-stream (RFC 2616, section 7.2.1) + {false, stringMap{}}, + {true, stringMap{"Content-Type": {"text/plain; boundary="}}}, + {false, stringMap{"Content-Type": {"application/unknown"}}}, +} + +func TestParseFormUnknownContentType(t *testing.T) { + for i, test := range parseContentTypeTests { + req := &Request{ + Method: "POST", + Header: Header(test.contentType), + Body: ioutil.NopCloser(strings.NewReader("body")), + } + err := req.ParseForm() + switch { + case err == nil && test.shouldError: + t.Errorf("test %d should have returned error", i) + case err != nil && !test.shouldError: + t.Errorf("test %d should not have returned error, got %v", i, err) + } + } +} + +func TestParseFormInitializeOnError(t *testing.T) { + nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil) + tests := []*Request{ + nilBody, + {Method: "GET", URL: nil}, + } + for i, req := range tests { + err := req.ParseForm() + if req.Form == nil { + t.Errorf("%d. Form not initialized, error %v", i, err) + } + if req.PostForm == nil { + t.Errorf("%d. PostForm not initialized, error %v", i, err) + } + } +} + +func TestMultipartReader(t *testing.T) { + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, + Body: ioutil.NopCloser(new(bytes.Buffer)), + } + multipart, err := req.MultipartReader() + if multipart == nil { + t.Errorf("expected multipart; error: %v", err) + } + + req.Header = Header{"Content-Type": {"text/plain"}} + multipart, err = req.MultipartReader() + if multipart != nil { + t.Error("unexpected multipart for text/plain") + } +} + +func TestParseMultipartForm(t *testing.T) { + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, + Body: ioutil.NopCloser(new(bytes.Buffer)), + } + err := req.ParseMultipartForm(25) + if err == nil { + t.Error("expected multipart EOF, got nil") + } + + req.Header = Header{"Content-Type": {"text/plain"}} + err = req.ParseMultipartForm(25) + if err != ErrNotMultipart { + t.Error("expected ErrNotMultipart for text/plain") + } +} + +func TestRedirect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.URL.Path { + case "/": + w.Header().Set("Location", "/foo/") + w.WriteHeader(StatusSeeOther) + case "/foo/": + fmt.Fprintf(w, "foo") + default: + w.WriteHeader(StatusBadRequest) + } + })) + defer ts.Close() + + var end = regexp.MustCompile("/foo/$") + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + r.Body.Close() + url := r.Request.URL.String() + if r.StatusCode != 200 || !end.MatchString(url) { + t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url) + } +} + +func TestSetBasicAuth(t *testing.T) { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.SetBasicAuth("Aladdin", "open sesame") + if g, e := r.Header.Get("Authorization"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; g != e { + t.Errorf("got header %q, want %q", g, e) + } +} + +func TestMultipartRequest(t *testing.T) { + // Test that we can read the values and files of a + // multipart request with FormValue and FormFile, + // and that ParseMultipartForm can be called multiple times. + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm first call:", err) + } + defer req.MultipartForm.RemoveAll() + validateTestMultipartContents(t, req, false) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm second call:", err) + } + validateTestMultipartContents(t, req, false) +} + +func TestMultipartRequestAuto(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req := newTestMultipartRequest(t) + defer func() { + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } + }() + validateTestMultipartContents(t, req, true) +} + +func TestMissingFileMultipartRequest(t *testing.T) { + // Test that FormFile returns an error if + // the named file is missing. + req := newTestMultipartRequest(t) + testMissingFile(t, req) +} + +// Test that FormValue invokes ParseMultipartForm. +func TestFormValueCallsParseMultipartForm(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/", strings.NewReader("z=post")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + if req.Form != nil { + t.Fatal("Unexpected request Form, want nil") + } + req.FormValue("z") + if req.Form == nil { + t.Fatal("ParseMultipartForm not called by FormValue") + } +} + +// Test that FormFile invokes ParseMultipartForm. +func TestFormFileCallsParseMultipartForm(t *testing.T) { + req := newTestMultipartRequest(t) + if req.Form != nil { + t.Fatal("Unexpected request Form, want nil") + } + req.FormFile("") + if req.Form == nil { + t.Fatal("ParseMultipartForm not called by FormFile") + } +} + +// Test that ParseMultipartForm errors if called +// after MultipartReader on the same request. +func TestParseMultipartFormOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if _, err := req.MultipartReader(); err != nil { + t.Fatalf("MultipartReader: %v", err) + } + if err := req.ParseMultipartForm(1024); err == nil { + t.Fatal("expected an error from ParseMultipartForm after call to MultipartReader") + } +} + +// Test that MultipartReader errors if called +// after ParseMultipartForm on the same request. +func TestMultipartReaderOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatalf("ParseMultipartForm: %v", err) + } + defer req.MultipartForm.RemoveAll() + if _, err := req.MultipartReader(); err == nil { + t.Fatal("expected an error from MultipartReader after call to ParseMultipartForm") + } +} + +// Test that FormFile errors if called after +// MultipartReader on the same request. +func TestFormFileOrder(t *testing.T) { + req := newTestMultipartRequest(t) + if _, err := req.MultipartReader(); err != nil { + t.Fatalf("MultipartReader: %v", err) + } + if _, _, err := req.FormFile(""); err == nil { + t.Fatal("expected an error from FormFile after call to MultipartReader") + } +} + +var readRequestErrorTests = []struct { + in string + err error +}{ + {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", nil}, + {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF}, + {"", io.EOF}, +} + +func TestReadRequestErrors(t *testing.T) { + for i, tt := range readRequestErrorTests { + _, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in))) + if err != tt.err { + t.Errorf("%d. got error = %v; want %v", i, err, tt.err) + } + } +} + +func TestNewRequestHost(t *testing.T) { + req, err := NewRequest("GET", "http://localhost:1234/", nil) + if err != nil { + t.Fatal(err) + } + if req.Host != "localhost:1234" { + t.Errorf("Host = %q; want localhost:1234", req.Host) + } +} + +func TestNewRequestContentLength(t *testing.T) { + readByte := func(r io.Reader) io.Reader { + var b [1]byte + r.Read(b[:]) + return r + } + tests := []struct { + r io.Reader + want int64 + }{ + {bytes.NewReader([]byte("123")), 3}, + {bytes.NewBuffer([]byte("1234")), 4}, + {strings.NewReader("12345"), 5}, + // Not detected: + {struct{ io.Reader }{strings.NewReader("xyz")}, 0}, + {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0}, + {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0}, + } + for _, tt := range tests { + req, err := NewRequest("POST", "http://localhost/", tt.r) + if err != nil { + t.Fatal(err) + } + if req.ContentLength != tt.want { + t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + } + } +} + +var parseHTTPVersionTests = []struct { + vers string + major, minor int + ok bool +}{ + {"HTTP/0.9", 0, 9, true}, + {"HTTP/1.0", 1, 0, true}, + {"HTTP/1.1", 1, 1, true}, + {"HTTP/3.14", 3, 14, true}, + + {"HTTP", 0, 0, false}, + {"HTTP/one.one", 0, 0, false}, + {"HTTP/1.1/", 0, 0, false}, + {"HTTP/-1,0", 0, 0, false}, + {"HTTP/0,-1", 0, 0, false}, + {"HTTP/", 0, 0, false}, + {"HTTP/1,1", 0, 0, false}, +} + +func TestParseHTTPVersion(t *testing.T) { + for _, tt := range parseHTTPVersionTests { + major, minor, ok := ParseHTTPVersion(tt.vers) + if ok != tt.ok || major != tt.major || minor != tt.minor { + type version struct { + major, minor int + ok bool + } + t.Errorf("failed to parse %q, expected: %#v, got %#v", tt.vers, version{tt.major, tt.minor, tt.ok}, version{major, minor, ok}) + } + } +} + +type getBasicAuthTest struct { + username, password string + ok bool +} + +type parseBasicAuthTest getBasicAuthTest + +type basicAuthCredentialsTest struct { + username, password string +} + +var getBasicAuthTests = []struct { + username, password string + ok bool +}{ + {"Aladdin", "open sesame", true}, + {"Aladdin", "open:sesame", true}, + {"", "", true}, +} + +func TestGetBasicAuth(t *testing.T) { + for _, tt := range getBasicAuthTests { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.SetBasicAuth(tt.username, tt.password) + username, password, ok := r.BasicAuth() + if ok != tt.ok || username != tt.username || password != tt.password { + t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok}, + getBasicAuthTest{tt.username, tt.password, tt.ok}) + } + } + // Unauthenticated request. + r, _ := NewRequest("GET", "http://example.com/", nil) + username, password, ok := r.BasicAuth() + if ok { + t.Errorf("expected false from BasicAuth when the request is unauthenticated") + } + want := basicAuthCredentialsTest{"", ""} + if username != want.username || password != want.password { + t.Errorf("expected credentials: %#v when the request is unauthenticated, got %#v", + want, basicAuthCredentialsTest{username, password}) + } +} + +var parseBasicAuthTests = []struct { + header, username, password string + ok bool +}{ + {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true}, + {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true}, + {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true}, + {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, + {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, + {"Basic ", "", "", false}, + {"Basic Aladdin:open sesame", "", "", false}, + {`Digest username="Aladdin"`, "", "", false}, +} + +func TestParseBasicAuth(t *testing.T) { + for _, tt := range parseBasicAuthTests { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.Header.Set("Authorization", tt.header) + username, password, ok := r.BasicAuth() + if ok != tt.ok || username != tt.username || password != tt.password { + t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok}, + getBasicAuthTest{tt.username, tt.password, tt.ok}) + } + } +} + +type logWrites struct { + t *testing.T + dst *[]string +} + +func (l logWrites) WriteByte(c byte) error { + l.t.Fatalf("unexpected WriteByte call") + return nil +} + +func (l logWrites) Write(p []byte) (n int, err error) { + *l.dst = append(*l.dst, string(p)) + return len(p), nil +} + +func TestRequestWriteBufferedWriter(t *testing.T) { + got := []string{} + req, _ := NewRequest("GET", "http://foo.com/", nil) + req.Write(logWrites{t, &got}) + want := []string{ + "GET / HTTP/1.1\r\n", + "Host: foo.com\r\n", + "User-Agent: " + DefaultUserAgent + "\r\n", + "\r\n", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Writes = %q\n Want = %q", got, want) + } +} + +func testMissingFile(t *testing.T, req *Request) { + f, fh, err := req.FormFile("missing") + if f != nil { + t.Errorf("FormFile file = %v, want nil", f) + } + if fh != nil { + t.Errorf("FormFile file header = %q, want nil", fh) + } + if err != ErrMissingFile { + t.Errorf("FormFile err = %q, want ErrMissingFile", err) + } +} + +func newTestMultipartRequest(t *testing.T) *Request { + b := strings.NewReader(strings.Replace(message, "\n", "\r\n", -1)) + req, err := NewRequest("POST", "/", b) + if err != nil { + t.Fatal("NewRequest:", err) + } + ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) + req.Header.Set("Content-type", ctype) + return req +} + +func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := req.FormValue("textb"), textbValue; g != e { + t.Errorf("textb value = %q, want %q", g, e) + } + if g := req.FormValue("missing"); g != "" { + t.Errorf("missing value = %q, want empty string", g) + } + + assertMem := func(n string, fd multipart.File) { + if _, ok := fd.(*os.File); ok { + t.Error(n, " is *os.File, should not be") + } + } + fda := testMultipartFile(t, req, "filea", "filea.txt", fileaContents) + defer fda.Close() + assertMem("filea", fda) + fdb := testMultipartFile(t, req, "fileb", "fileb.txt", filebContents) + defer fdb.Close() + if allMem { + assertMem("fileb", fdb) + } else { + if _, ok := fdb.(*os.File); !ok { + t.Errorf("fileb has unexpected underlying type %T", fdb) + } + } + + testMissingFile(t, req) +} + +func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { + f, fh, err := req.FormFile(key) + if err != nil { + t.Fatalf("FormFile(%q): %q", key, err) + } + if fh.Filename != expectFilename { + t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) + } + var b bytes.Buffer + _, err = io.Copy(&b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != expectContent { + t.Errorf("contents = %q, want %q", g, expectContent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` + +func benchmarkReadRequest(b *testing.B, request string) { + request = request + "\n" // final \n + request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n + b.SetBytes(int64(len(request))) + r := bufio.NewReader(&infiniteReader{buf: []byte(request)}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ReadRequest(r) + if err != nil { + b.Fatalf("failed to read request: %v", err) + } + } +} + +// infiniteReader satisfies Read requests as if the contents of buf +// loop indefinitely. +type infiniteReader struct { + buf []byte + offset int +} + +func (r *infiniteReader) Read(b []byte) (int, error) { + n := copy(b, r.buf[r.offset:]) + r.offset = (r.offset + n) % len(r.buf) + return n, nil +} + +func BenchmarkReadRequestChrome(b *testing.B) { + // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Connection: keep-alive +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false +`) +} + +func BenchmarkReadRequestCurl(b *testing.B) { + // curl http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +User-Agent: curl/7.27.0 +Host: localhost:8080 +Accept: */* +`) +} + +func BenchmarkReadRequestApachebench(b *testing.B) { + // ab -n 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.0 +Host: localhost:8080 +User-Agent: ApacheBench/2.3 +Accept: */* +`) +} + +func BenchmarkReadRequestSiege(b *testing.B) { + // siege -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Accept: */* +Accept-Encoding: gzip +User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70) +Connection: keep-alive +`) +} + +func BenchmarkReadRequestWrk(b *testing.B) { + // wrk -t 1 -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +`) +} diff --git a/src/net/http/requestwrite_test.go b/src/net/http/requestwrite_test.go new file mode 100644 index 000000000..7a6bd5878 --- /dev/null +++ b/src/net/http/requestwrite_test.go @@ -0,0 +1,623 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/url" + "strings" + "testing" +) + +type reqWriteTest struct { + Req Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + // Any of these three may be empty to skip that test. + WantWrite string // Request.Write + WantProxy string // Request.WriteProxy + + WantError error // wanted error from Request.Write +} + +var reqWriteTests = []reqWriteTest{ + // HTTP/1.1 => chunked coding; no body; no trailer + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.techcrunch.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "User-Agent": {"Fake"}, + }, + Body: nil, + Close: false, + Host: "www.techcrunch.com", + Form: map[string][]string{}, + }, + + WantWrite: "GET / HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Keep-Alive: 300\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n", + + WantProxy: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Keep-Alive: 300\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n", + }, + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + + WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + // HTTP/1.1 POST => chunked coding; body; empty trailer + { + Req: Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantWrite: "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + + WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // HTTP/1.1 POST with Content-Length, no chunking + { + Req: Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + ContentLength: 6, + }, + + Body: []byte("abcdef"), + + WantWrite: "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + + // HTTP/1.1 POST with Content-Length in headers + { + Req: Request{ + Method: "POST", + URL: mustParseURL("http://example.com/"), + Host: "example.com", + Header: Header{ + "Content-Length": []string{"10"}, // ignored + }, + ContentLength: 6, + }, + + Body: []byte("abcdef"), + + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + WantProxy: "POST http://example.com/ HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + + // default to HTTP/1.1 + { + Req: Request{ + Method: "GET", + URL: mustParseURL("/search"), + Host: "www.google.com", + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "\r\n", + }, + + // Request with a 0 ContentLength and a 0 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, + + // RFC 2616 Section 14.13 says Content-Length should be specified + // unless body is prohibited by the request method. + // Also, nginx expects it for POST and PUT. + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + WantProxy: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + + // Request with a 0 ContentLength and a 1 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, + + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("x") + chunk(""), + + WantProxy: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("x") + chunk(""), + }, + + // Request with a ContentLength of 10 but a 5 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 10, // but we're going to send only 5 bytes + }, + Body: []byte("12345"), + WantError: errors.New("http: ContentLength=10 with Body length 5"), + }, + + // Request with a ContentLength of 4 but an 8 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 4, // but we're going to try to send 8 bytes + }, + Body: []byte("12345678"), + WantError: errors.New("http: ContentLength=4 with Body length 8"), + }, + + // Request with a 5 ContentLength and nil body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 5, // but we'll omit the body + }, + WantError: errors.New("http: Request.ContentLength=5 with nil Body"), + }, + + // Request with a 0 ContentLength and a body with 1 byte content and an error. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { + err := errors.New("Custom reader error") + errReader := &errorReader{err} + return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) + }, + + WantError: errors.New("Custom reader error"), + }, + + // Request with a 0 ContentLength and a body without content and an error. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { + err := errors.New("Custom reader error") + errReader := &errorReader{err} + return ioutil.NopCloser(errReader) + }, + + WantError: errors.New("Custom reader error"), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantWrite: "GET /foo HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, + + // If no Request.Host and no Request.URL.Host, we send + // an empty Host header, and don't use + // Request.Header["Host"]. This is just testing that + // we don't change Go 1.0 behavior. + { + Req: Request{ + Method: "GET", + Host: "", + URL: &url.URL{ + Scheme: "http", + Host: "", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Host": []string{"bad.example.com"}, + }, + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Opaque test #1 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Opaque: "/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Opaque test #2 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "x.google.com", + Opaque: "//y.google.com/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" + + "Host: x.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Testing custom case in header keys. Issue 5022. + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "ALL-CAPS": {"x"}, + }, + }, + + WantWrite: "GET / HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "ALL-CAPS: x\r\n" + + "\r\n", + }, +} + +func TestRequestWrite(t *testing.T) { + for i := range reqWriteTests { + tt := &reqWriteTests[i] + + setBody := func() { + if tt.Body == nil { + return + } + switch b := tt.Body.(type) { + case []byte: + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) + case func() io.ReadCloser: + tt.Req.Body = b() + } + } + setBody() + if tt.Req.Header == nil { + tt.Req.Header = make(Header) + } + + var braw bytes.Buffer + err := tt.Req.Write(&braw) + if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e { + t.Errorf("writing #%d, err = %q, want %q", i, g, e) + continue + } + if err != nil { + continue + } + + if tt.WantWrite != "" { + sraw := braw.String() + if sraw != tt.WantWrite { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantWrite, sraw) + continue + } + } + + if tt.WantProxy != "" { + setBody() + var praw bytes.Buffer + err = tt.Req.WriteProxy(&praw) + if err != nil { + t.Errorf("WriteProxy #%d: %s", i, err) + continue + } + sraw := praw.String() + if sraw != tt.WantProxy { + t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantProxy, sraw) + continue + } + } + } +} + +type closeChecker struct { + io.Reader + closed bool +} + +func (rc *closeChecker) Close() error { + rc.closed = true + return nil +} + +// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer +// inside a NopCloser, and that it serializes it correctly. +func TestRequestWriteClosesBody(t *testing.T) { + rc := &closeChecker{Reader: strings.NewReader("my body")} + req, _ := NewRequest("POST", "http://foo.com/", rc) + if req.ContentLength != 0 { + t.Errorf("got req.ContentLength %d, want 0", req.ContentLength) + } + buf := new(bytes.Buffer) + req.Write(buf) + if !rc.closed { + t.Error("body not closed after write") + } + expected := "POST / HTTP/1.1\r\n" + + "Host: foo.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + // TODO: currently we don't buffer before chunking, so we get a + // single "m" chunk before the other chunks, as this was the 1-byte + // read from our MultiReader where we stiched the Body back together + // after sniffing whether the Body was 0 bytes or not. + chunk("m") + + chunk("y body") + + chunk("") + if buf.String() != expected { + t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected) + } +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} + +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { return f(p) } + +// TestRequestWriteError tests the Write err != nil checks in (*Request).write. +func TestRequestWriteError(t *testing.T) { + failAfter, writeCount := 0, 0 + errFail := errors.New("fake write failure") + + // w is the buffered io.Writer to write the request to. It + // fails exactly once on its Nth Write call, as controlled by + // failAfter. It also tracks the number of calls in + // writeCount. + w := struct { + io.ByteWriter // to avoid being wrapped by a bufio.Writer + io.Writer + }{ + nil, + writerFunc(func(p []byte) (n int, err error) { + writeCount++ + if failAfter == 0 { + err = errFail + } + failAfter-- + return len(p), err + }), + } + + req, _ := NewRequest("GET", "http://example.com/", nil) + const writeCalls = 4 // number of Write calls in current implementation + sawGood := false + for n := 0; n <= writeCalls+2; n++ { + failAfter = n + writeCount = 0 + err := req.Write(w) + var wantErr error + if n < writeCalls { + wantErr = errFail + } + if err != wantErr { + t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr) + continue + } + if err == nil { + sawGood = true + if writeCount != writeCalls { + t.Fatalf("writeCalls constant is outdated in test") + } + } + if writeCount > writeCalls || writeCount > n+1 { + t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount) + } + } + if !sawGood { + t.Fatalf("writeCalls constant is outdated in test") + } +} diff --git a/src/net/http/response.go b/src/net/http/response.go new file mode 100644 index 000000000..5d2c39080 --- /dev/null +++ b/src/net/http/response.go @@ -0,0 +1,291 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Response reading and parsing. + +package http + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "io" + "net/textproto" + "net/url" + "strconv" + "strings" +) + +var respExcludeHeader = map[string]bool{ + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// Response represents the response from an HTTP request. +// +type Response struct { + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 + + // Header maps header keys to values. If the response had multiple + // headers with the same key, they may be concatenated, with comma + // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers + // be semantically equivalent to a comma-delimited sequence.) Values + // duplicated by other fields in this struct (e.g., ContentLength) are + // omitted from Header. + // + // Keys in the map are canonicalized (see CanonicalHeaderKey). + Header Header + + // Body represents the response body. + // + // The http Client and Transport guarantee that Body is always + // non-nil, even on responses without a body or responses with + // a zero-length body. It is the caller's responsibility to + // close Body. + // + // The Body is automatically dechunked if the server replied + // with a "chunked" Transfer-Encoding. + Body io.ReadCloser + + // ContentLength records the length of the associated content. The + // value -1 indicates that the length is unknown. Unless Request.Method + // is "HEAD", values >= 0 indicate that the given number of bytes may + // be read from Body. + ContentLength int64 + + // Contains transfer encodings from outer-most to inner-most. Value is + // nil, means that "identity" encoding is used. + TransferEncoding []string + + // Close records whether the header directed that the connection be + // closed after reading Body. The value is advice for clients: neither + // ReadResponse nor Response.Write ever closes a connection. + Close bool + + // Trailer maps trailer keys to values, in the same + // format as the header. + Trailer Header + + // The Request that was sent to obtain this Response. + // Request's Body is nil (having already been consumed). + // This is only populated for Client requests. + Request *Request + + // TLS contains information about the TLS connection on which the + // response was received. It is nil for unencrypted responses. + // The pointer is shared between responses and should not be + // modified. + TLS *tls.ConnectionState +} + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + +var ErrNoLocation = errors.New("http: no Location header in response") + +// Location returns the URL of the response's "Location" header, +// if present. Relative redirects are resolved relative to +// the Response's Request. ErrNoLocation is returned if no +// Location header is present. +func (r *Response) Location() (*url.URL, error) { + lv := r.Header.Get("Location") + if lv == "" { + return nil, ErrNoLocation + } + if r.Request != nil && r.Request.URL != nil { + return r.Request.URL.Parse(lv) + } + return url.Parse(lv) +} + +// ReadResponse reads and returns an HTTP response from r. +// The req parameter optionally specifies the Request that corresponds +// to this Response. If nil, a GET request is assumed. +// Clients must call resp.Body.Close when finished reading resp.Body. +// After that call, clients can inspect resp.Trailer to find key/value +// pairs included in the response trailer. +func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { + tp := textproto.NewReader(r) + resp := &Response{ + Request: req, + } + + // Parse the first line of the response. + line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + f := strings.SplitN(line, " ", 3) + if len(f) < 2 { + return nil, &badStringError{"malformed HTTP response", line} + } + reasonPhrase := "" + if len(f) > 2 { + reasonPhrase = f[2] + } + resp.Status = f[1] + " " + reasonPhrase + resp.StatusCode, err = strconv.Atoi(f[1]) + if err != nil { + return nil, &badStringError{"malformed HTTP status code", f[1]} + } + + resp.Proto = f[0] + var ok bool + if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { + return nil, &badStringError{"malformed HTTP version", resp.Proto} + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + err = readTransfer(resp, r) + if err != nil { + return nil, err + } + + return resp, nil +} + +// RFC2616: Should treat +// Pragma: no-cache +// like +// Cache-Control: no-cache +func fixPragmaCacheControl(header Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +// ProtoAtLeast reports whether the HTTP protocol used +// in the response is at least major.minor. +func (r *Response) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + +// Writes the response (header, body and trailer) in wire format. This method +// consults the following fields of the response: +// +// StatusCode +// ProtoMajor +// ProtoMinor +// Request.Method +// TransferEncoding +// Trailer +// Body +// ContentLength +// Header, values for non-canonical keys will have unpredictable behavior +// +// Body is closed after it is sent. +func (r *Response) Write(w io.Writer) error { + // Status line + text := r.Status + if text == "" { + var ok bool + text, ok = statusText[r.StatusCode] + if !ok { + text = "status code " + strconv.Itoa(r.StatusCode) + } + } + protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) + statusCode := strconv.Itoa(r.StatusCode) + " " + text = strings.TrimPrefix(text, statusCode) + if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil { + return err + } + + // Clone it, so we can modify r1 as needed. + r1 := new(Response) + *r1 = *r + if r1.ContentLength == 0 && r1.Body != nil { + // Is it actually 0 length? Or just unknown? + var buf [1]byte + n, err := r1.Body.Read(buf[:]) + if err != nil && err != io.EOF { + return err + } + if n == 0 { + // Reset it to a known zero reader, in case underlying one + // is unhappy being read repeatedly. + r1.Body = eofReader + } else { + r1.ContentLength = -1 + r1.Body = struct { + io.Reader + io.Closer + }{ + io.MultiReader(bytes.NewReader(buf[:1]), r.Body), + r.Body, + } + } + } + // If we're sending a non-chunked HTTP/1.1 response without a + // content-length, the only way to do that is the old HTTP/1.0 + // way, by noting the EOF with a connection close, so we need + // to set Close. + if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) { + r1.Close = true + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(r1) + if err != nil { + return err + } + err = tw.WriteHeader(w) + if err != nil { + return err + } + + // Rest of header + err = r.Header.WriteSubset(w, respExcludeHeader) + if err != nil { + return err + } + + // contentLengthAlreadySent may have been already sent for + // POST/PUT requests, even if zero length. See Issue 8180. + contentLengthAlreadySent := tw.shouldSendContentLength() + if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent { + if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil { + return err + } + } + + // End-of-header + if _, err := io.WriteString(w, "\r\n"); err != nil { + return err + } + + // Write body and trailer + err = tw.WriteBody(w) + if err != nil { + return err + } + + // Success + return nil +} diff --git a/src/net/http/response_test.go b/src/net/http/response_test.go new file mode 100644 index 000000000..06e940d9a --- /dev/null +++ b/src/net/http/response_test.go @@ -0,0 +1,674 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "compress/gzip" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "net/http/internal" + "net/url" + "reflect" + "regexp" + "strings" + "testing" +) + +type respTest struct { + Raw string + Resp Response + Body string +} + +func dummyReq(method string) *Request { + return &Request{Method: method} +} + +func dummyReq11(method string) *Request { + return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1} +} + +var respTests = []respTest{ + // Unchunked response without Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? + }, + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, + + // Unchunked HTTP/1.1 response without Content-Length or + // Connection headers. + { + "HTTP/1.1 200 OK\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Request: dummyReq("GET"), + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, + + // Unchunked HTTP/1.1 204 response without Content-Length. + { + "HTTP/1.1 204 No Content\r\n" + + "\r\n" + + "Body should not be read!\n", + + Response{ + Status: "204 No Content", + StatusCode: 204, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Request: dummyReq("GET"), + Close: false, + ContentLength: 0, + }, + + "", + }, + + // Unchunked response with Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Content-Length: 10\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{ + "Connection": {"close"}, + "Content-Length": {"10"}, + }, + Close: true, + ContentLength: 10, + }, + + "Body here\n", + }, + + // Chunked response without Content-Length. + { + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0a\r\n" + + "Body here\n\r\n" + + "09\r\n" + + "continued\r\n" + + "0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Close: false, + ContentLength: -1, + TransferEncoding: []string{"chunked"}, + }, + + "Body here\ncontinued", + }, + + // Chunked response with Content-Length. + { + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Content-Length: 10\r\n" + + "\r\n" + + "0a\r\n" + + "Body here\n\r\n" + + "0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Close: false, + ContentLength: -1, + TransferEncoding: []string{"chunked"}, + }, + + "Body here\n", + }, + + // Chunked response in response to a HEAD request + { + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("HEAD"), + Header: Header{}, + TransferEncoding: []string{"chunked"}, + Close: false, + ContentLength: -1, + }, + + "", + }, + + // Content-Length in response to a HEAD request + { + "HTTP/1.0 200 OK\r\n" + + "Content-Length: 256\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{"Content-Length": {"256"}}, + TransferEncoding: nil, + Close: true, + ContentLength: 256, + }, + + "", + }, + + // Content-Length in response to a HEAD request with HTTP/1.1 + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 256\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("HEAD"), + Header: Header{"Content-Length": {"256"}}, + TransferEncoding: nil, + Close: false, + ContentLength: 256, + }, + + "", + }, + + // No Content-Length or Chunked in response to a HEAD request + { + "HTTP/1.0 200 OK\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{}, + TransferEncoding: nil, + Close: true, + ContentLength: -1, + }, + + "", + }, + + // explicit Content-Length of 0. + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Length": {"0"}, + }, + Close: false, + ContentLength: 0, + }, + + "", + }, + + // Status line without a Reason-Phrase, but trailing space. + // (permitted by RFC 2616) + { + "HTTP/1.0 303 \r\n\r\n", + Response{ + Status: "303 ", + StatusCode: 303, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, + }, + + "", + }, + + // Status line without a Reason-Phrase, and no trailing space. + // (not permitted by RFC 2616, but we'll accept it anyway) + { + "HTTP/1.0 303\r\n\r\n", + Response{ + Status: "303 ", + StatusCode: 303, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, + }, + + "", + }, + + // golang.org/issue/4767: don't special-case multipart/byteranges responses + { + `HTTP/1.1 206 Partial Content +Connection: close +Content-Type: multipart/byteranges; boundary=18a75608c8f47cef + +some body`, + Response{ + Status: "206 Partial Content", + StatusCode: 206, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"}, + }, + Close: true, + ContentLength: -1, + }, + + "some body", + }, + + // Unchunked response without Content-Length, Request is nil + { + "HTTP/1.0 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? + }, + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, + + // 206 Partial Content. golang.org/issue/8923 + { + "HTTP/1.1 206 Partial Content\r\n" + + "Content-Type: text/plain; charset=utf-8\r\n" + + "Accept-Ranges: bytes\r\n" + + "Content-Range: bytes 0-5/1862\r\n" + + "Content-Length: 6\r\n\r\n" + + "foobar", + + Response{ + Status: "206 Partial Content", + StatusCode: 206, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Accept-Ranges": []string{"bytes"}, + "Content-Length": []string{"6"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + "Content-Range": []string{"bytes 0-5/1862"}, + }, + ContentLength: 6, + }, + + "foobar", + }, +} + +func TestReadResponse(t *testing.T) { + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + rbody := resp.Body + resp.Body = nil + diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) + var bout bytes.Buffer + if rbody != nil { + _, err = io.Copy(&bout, rbody) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + rbody.Close() + } + body := bout.String() + if body != tt.Body { + t.Errorf("#%d: Body = %q want %q", i, body, tt.Body) + } + } +} + +func TestWriteResponse(t *testing.T) { + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + err = resp.Write(ioutil.Discard) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + } +} + +var readResponseCloseInMiddleTests = []struct { + chunked, compressed bool +}{ + {false, false}, + {true, false}, + {true, true}, +} + +// TestReadResponseCloseInMiddle tests that closing a body after +// reading only part of its contents advances the read to the end of +// the request, right up until the next request. +func TestReadResponseCloseInMiddle(t *testing.T) { + for _, test := range readResponseCloseInMiddleTests { + fatalf := func(format string, args ...interface{}) { + args = append([]interface{}{test.chunked, test.compressed}, args...) + t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) + } + checkErr := func(err error, msg string) { + if err == nil { + return + } + fatalf(msg+": %v", err) + } + var buf bytes.Buffer + buf.WriteString("HTTP/1.1 200 OK\r\n") + if test.chunked { + buf.WriteString("Transfer-Encoding: chunked\r\n") + } else { + buf.WriteString("Content-Length: 1000000\r\n") + } + var wr io.Writer = &buf + if test.chunked { + wr = internal.NewChunkedWriter(wr) + } + if test.compressed { + buf.WriteString("Content-Encoding: gzip\r\n") + wr = gzip.NewWriter(wr) + } + buf.WriteString("\r\n") + + chunk := bytes.Repeat([]byte{'x'}, 1000) + for i := 0; i < 1000; i++ { + if test.compressed { + // Otherwise this compresses too well. + _, err := io.ReadFull(rand.Reader, chunk) + checkErr(err, "rand.Reader ReadFull") + } + wr.Write(chunk) + } + if test.compressed { + err := wr.(*gzip.Writer).Close() + checkErr(err, "compressor close") + } + if test.chunked { + buf.WriteString("0\r\n\r\n") + } + buf.WriteString("Next Request Here") + + bufr := bufio.NewReader(&buf) + resp, err := ReadResponse(bufr, dummyReq("GET")) + checkErr(err, "ReadResponse") + expectedLength := int64(-1) + if !test.chunked { + expectedLength = 1000000 + } + if resp.ContentLength != expectedLength { + fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength) + } + if resp.Body == nil { + fatalf("nil body") + } + if test.compressed { + gzReader, err := gzip.NewReader(resp.Body) + checkErr(err, "gzip.NewReader") + resp.Body = &readerAndCloser{gzReader, resp.Body} + } + + rbuf := make([]byte, 2500) + n, err := io.ReadFull(resp.Body, rbuf) + checkErr(err, "2500 byte ReadFull") + if n != 2500 { + fatalf("ReadFull only read %d bytes", n) + } + if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) { + fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf)) + } + resp.Body.Close() + + rest, err := ioutil.ReadAll(bufr) + checkErr(err, "ReadAll on remainder") + if e, g := "Next Request Here", string(rest); e != g { + g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string { + return fmt.Sprintf("x(repeated x%d)", len(match)) + }) + fatalf("remainder = %q, expected %q", g, e) + } + } +} + +func diff(t *testing.T, prefix string, have, want interface{}) { + hv := reflect.ValueOf(have).Elem() + wv := reflect.ValueOf(want).Elem() + if hv.Type() != wv.Type() { + t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type()) + } + for i := 0; i < hv.NumField(); i++ { + hf := hv.Field(i).Interface() + wf := wv.Field(i).Interface() + if !reflect.DeepEqual(hf, wf) { + t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) + } + } +} + +type responseLocationTest struct { + location string // Response's Location header or "" + requrl string // Response.Request.URL or "" + want string + wantErr error +} + +var responseLocationTests = []responseLocationTest{ + {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil}, + {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil}, + {"", "http://bar.com/baz", "", ErrNoLocation}, +} + +func TestLocationResponse(t *testing.T) { + for i, tt := range responseLocationTests { + res := new(Response) + res.Header = make(Header) + res.Header.Set("Location", tt.location) + if tt.requrl != "" { + res.Request = &Request{} + var err error + res.Request.URL, err = url.Parse(tt.requrl) + if err != nil { + t.Fatalf("bad test URL %q: %v", tt.requrl, err) + } + } + + got, err := res.Location() + if tt.wantErr != nil { + if err == nil { + t.Errorf("%d. err=nil; want %q", i, tt.wantErr) + continue + } + if g, e := err.Error(), tt.wantErr.Error(); g != e { + t.Errorf("%d. err=%q; want %q", i, g, e) + continue + } + continue + } + if err != nil { + t.Errorf("%d. err=%q", i, err) + continue + } + if g, e := got.String(), tt.want; g != e { + t.Errorf("%d. Location=%q; want %q", i, g, e) + } + } +} + +func TestResponseStatusStutter(t *testing.T) { + r := &Response{ + Status: "123 some status", + StatusCode: 123, + ProtoMajor: 1, + ProtoMinor: 3, + } + var buf bytes.Buffer + r.Write(&buf) + if strings.Contains(buf.String(), "123 123") { + t.Errorf("stutter in status: %s", buf.String()) + } +} + +func TestResponseContentLengthShortBody(t *testing.T) { + const shortBody = "Short body, not 123 bytes." + br := bufio.NewReader(strings.NewReader("HTTP/1.1 200 OK\r\n" + + "Content-Length: 123\r\n" + + "\r\n" + + shortBody)) + res, err := ReadResponse(br, &Request{Method: "GET"}) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != 123 { + t.Fatalf("Content-Length = %d; want 123", res.ContentLength) + } + var buf bytes.Buffer + n, err := io.Copy(&buf, res.Body) + if n != int64(len(shortBody)) { + t.Errorf("Copied %d bytes; want %d, len(%q)", n, len(shortBody), shortBody) + } + if buf.String() != shortBody { + t.Errorf("Read body %q; want %q", buf.String(), shortBody) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("io.Copy error = %#v; want io.ErrUnexpectedEOF", err) + } +} + +func TestReadResponseUnexpectedEOF(t *testing.T) { + br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" + + "Location: http://example.com")) + _, err := ReadResponse(br, nil) + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err) + } +} + +func TestNeedsSniff(t *testing.T) { + // needsSniff returns true with an empty response. + r := &response{} + if got, want := r.needsSniff(), true; got != want { + t.Errorf("needsSniff = %t; want %t", got, want) + } + // needsSniff returns false when Content-Type = nil. + r.handlerHeader = Header{"Content-Type": nil} + if got, want := r.needsSniff(), false; got != want { + t.Errorf("needsSniff empty Content-Type = %t; want %t", got, want) + } +} diff --git a/src/net/http/responsewrite_test.go b/src/net/http/responsewrite_test.go new file mode 100644 index 000000000..585b13b85 --- /dev/null +++ b/src/net/http/responsewrite_test.go @@ -0,0 +1,226 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" +) + +type respWriteTest struct { + Resp Response + Raw string +} + +func TestResponseWrite(t *testing.T) { + respWriteTests := []respWriteTest{ + // HTTP/1.0, identity coding; no trailer + { + Response{ + StatusCode: 503, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: 6, + }, + + "HTTP/1.0 503 Service Unavailable\r\n" + + "Content-Length: 6\r\n\r\n" + + "abcdef", + }, + // Unchunked response without Content-Length. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + }, + "HTTP/1.0 200 OK\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and Connection: close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + Close: true, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and not setting connection: close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1 response with unknown length and not setting connection: close, but + // setting chunked. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: -1, + TransferEncoding: []string{"chunked"}, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, + // HTTP/1.1 response 0 content-length, and nil body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: nil, + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + // HTTP/1.1 response 0 content-length, and non-nil empty body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("")), + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + // HTTP/1.1 response 0 content-length, and non-nil non-empty body + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq11("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("foo")), + ContentLength: 0, + Close: false, + }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "\r\nfoo", + }, + // HTTP/1.1, chunked coding; empty trailer; close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(strings.NewReader("abcdef")), + ContentLength: 6, + TransferEncoding: []string{"chunked"}, + Close: true, + }, + + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, + + // Header value with a newline character (Issue 914). + // Also tests removal of leading and trailing whitespace. + { + Response{ + StatusCode: 204, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Foo": []string{" Bar\nBaz "}, + }, + Body: nil, + ContentLength: 0, + TransferEncoding: []string{"chunked"}, + Close: true, + }, + + "HTTP/1.1 204 No Content\r\n" + + "Connection: close\r\n" + + "Foo: Bar Baz\r\n" + + "\r\n", + }, + + // Want a single Content-Length header. Fixing issue 8180 where + // there were two. + { + Response{ + StatusCode: StatusOK, + ProtoMajor: 1, + ProtoMinor: 1, + Request: &Request{Method: "POST"}, + Header: Header{}, + ContentLength: 0, + TransferEncoding: nil, + Body: nil, + }, + "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", + }, + } + + for i := range respWriteTests { + tt := &respWriteTests[i] + var braw bytes.Buffer + err := tt.Resp.Write(&braw) + if err != nil { + t.Errorf("error writing #%d: %s", i, err) + continue + } + sraw := braw.String() + if sraw != tt.Raw { + t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw) + continue + } + } +} diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go new file mode 100644 index 000000000..5e0a0053c --- /dev/null +++ b/src/net/http/serve_test.go @@ -0,0 +1,3094 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// End-to-end serving tests + +package http_test + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + . "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "os" + "os/exec" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" +) + +type dummyAddr string +type oneConnListener struct { + conn net.Conn +} + +func (l *oneConnListener) Accept() (c net.Conn, err error) { + c = l.conn + if c == nil { + err = io.EOF + return + } + err = nil + l.conn = nil + return +} + +func (l *oneConnListener) Close() error { + return nil +} + +func (l *oneConnListener) Addr() net.Addr { + return dummyAddr("test-address") +} + +func (a dummyAddr) Network() string { + return string(a) +} + +func (a dummyAddr) String() string { + return string(a) +} + +type noopConn struct{} + +func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } +func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } +func (noopConn) SetDeadline(t time.Time) error { return nil } +func (noopConn) SetReadDeadline(t time.Time) error { return nil } +func (noopConn) SetWriteDeadline(t time.Time) error { return nil } + +type rwTestConn struct { + io.Reader + io.Writer + noopConn + + closeFunc func() error // called if non-nil + closec chan bool // else, if non-nil, send value to it on close +} + +func (c *rwTestConn) Close() error { + if c.closeFunc != nil { + return c.closeFunc() + } + select { + case c.closec <- true: + default: + } + return nil +} + +type testConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer + closec chan bool // if non-nil, send value to it on close + noopConn +} + +func (c *testConn) Read(b []byte) (int, error) { + return c.readBuf.Read(b) +} + +func (c *testConn) Write(b []byte) (int, error) { + return c.writeBuf.Write(b) +} + +func (c *testConn) Close() error { + select { + case c.closec <- true: + default: + } + return nil +} + +// reqBytes treats req as a request (with \n delimiters) and returns it with \r\n delimiters, +// ending in \r\n\r\n +func reqBytes(req string) []byte { + return []byte(strings.Replace(strings.TrimSpace(req), "\n", "\r\n", -1) + "\r\n\r\n") +} + +type handlerTest struct { + handler Handler +} + +func newHandlerTest(h Handler) handlerTest { + return handlerTest{h} +} + +func (ht handlerTest) rawResponse(req string) string { + reqb := reqBytes(req) + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(reqb), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + go Serve(ln, ht.handler) + <-conn.closec + return output.String() +} + +func TestConsumingBodyOnNextConn(t *testing.T) { + conn := new(testConn) + for i := 0; i < 2; i++ { + conn.readBuf.Write([]byte( + "POST / HTTP/1.1\r\n" + + "Host: test\r\n" + + "Content-Length: 11\r\n" + + "\r\n" + + "foo=1&bar=1")) + } + + reqNum := 0 + ch := make(chan *Request) + servech := make(chan error) + listener := &oneConnListener{conn} + handler := func(res ResponseWriter, req *Request) { + reqNum++ + ch <- req + } + + go func() { + servech <- Serve(listener, HandlerFunc(handler)) + }() + + var req *Request + req = <-ch + if req == nil { + t.Fatal("Got nil first request.") + } + if req.Method != "POST" { + t.Errorf("For request #1's method, got %q; expected %q", + req.Method, "POST") + } + + req = <-ch + if req == nil { + t.Fatal("Got nil first request.") + } + if req.Method != "POST" { + t.Errorf("For request #2's method, got %q; expected %q", + req.Method, "POST") + } + + if serveerr := <-servech; serveerr != io.EOF { + t.Errorf("Serve returned %q; expected EOF", serveerr) + } +} + +type stringHandler string + +func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) { + w.Header().Set("Result", string(s)) +} + +var handlers = []struct { + pattern string + msg string +}{ + {"/", "Default"}, + {"/someDir/", "someDir"}, + {"someHost.com/someDir/", "someHost.com/someDir"}, +} + +var vtests = []struct { + url string + expected string +}{ + {"http://localhost/someDir/apage", "someDir"}, + {"http://localhost/otherDir/apage", "Default"}, + {"http://someHost.com/someDir/apage", "someHost.com/someDir"}, + {"http://otherHost.com/someDir/apage", "someDir"}, + {"http://otherHost.com/aDir/apage", "Default"}, + // redirections for trees + {"http://localhost/someDir", "/someDir/"}, + {"http://someHost.com/someDir", "/someDir/"}, +} + +func TestHostHandlers(t *testing.T) { + defer afterTest(t) + mux := NewServeMux() + for _, h := range handlers { + mux.Handle(h.pattern, stringHandler(h.msg)) + } + ts := httptest.NewServer(mux) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + cc := httputil.NewClientConn(conn, nil) + for _, vt := range vtests { + var r *Response + var req Request + if req.URL, err = url.Parse(vt.url); err != nil { + t.Errorf("cannot parse url: %v", err) + continue + } + if err := cc.Write(&req); err != nil { + t.Errorf("writing request: %v", err) + continue + } + r, err := cc.Read(&req) + if err != nil { + t.Errorf("reading response: %v", err) + continue + } + switch r.StatusCode { + case StatusOK: + s := r.Header.Get("Result") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + case StatusMovedPermanently: + s := r.Header.Get("Location") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + default: + t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode) + } + } +} + +var serveMuxRegister = []struct { + pattern string + h Handler +}{ + {"/dir/", serve(200)}, + {"/search", serve(201)}, + {"codesearch.google.com/search", serve(202)}, + {"codesearch.google.com/", serve(203)}, + {"example.com/", HandlerFunc(checkQueryStringHandler)}, +} + +// serve returns a handler that sends a response with the given code. +func serve(code int) HandlerFunc { + return func(w ResponseWriter, r *Request) { + w.WriteHeader(code) + } +} + +// checkQueryStringHandler checks if r.URL.RawQuery has the same value +// as the URL excluding the scheme and the query string and sends 200 +// response code if it is, 500 otherwise. +func checkQueryStringHandler(w ResponseWriter, r *Request) { + u := *r.URL + u.Scheme = "http" + u.Host = r.Host + u.RawQuery = "" + if "http://"+r.URL.RawQuery == u.String() { + w.WriteHeader(200) + } else { + w.WriteHeader(500) + } +} + +var serveMuxTests = []struct { + method string + host string + path string + code int + pattern string +}{ + {"GET", "google.com", "/", 404, ""}, + {"GET", "google.com", "/dir", 301, "/dir/"}, + {"GET", "google.com", "/dir/", 200, "/dir/"}, + {"GET", "google.com", "/dir/file", 200, "/dir/"}, + {"GET", "google.com", "/search", 201, "/search"}, + {"GET", "google.com", "/search/", 404, ""}, + {"GET", "google.com", "/search/foo", 404, ""}, + {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"}, + {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, + {"GET", "images.google.com", "/search", 201, "/search"}, + {"GET", "images.google.com", "/search/", 404, ""}, + {"GET", "images.google.com", "/search/foo", 404, ""}, + {"GET", "google.com", "/../search", 301, "/search"}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/./file", 301, "/dir/"}, + + // The /foo -> /foo/ redirect applies to CONNECT requests + // but the path canonicalization does not. + {"CONNECT", "google.com", "/dir", 301, "/dir/"}, + {"CONNECT", "google.com", "/../search", 404, ""}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"}, +} + +func TestServeMuxHandler(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests { + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: &url.URL{ + Path: tt.path, + }, + } + h, pattern := mux.Handler(r) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if pattern != tt.pattern || rr.Code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern) + } + } +} + +var serveMuxTests2 = []struct { + method string + host string + url string + code int + redirOk bool +}{ + {"GET", "google.com", "/", 404, false}, + {"GET", "example.com", "/test/?example.com/test/", 200, false}, + {"GET", "example.com", "test/?example.com/test/", 200, true}, +} + +// TestServeMuxHandlerRedirects tests that automatic redirects generated by +// mux.Handler() shouldn't clear the request's query string. +func TestServeMuxHandlerRedirects(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests2 { + tries := 1 + turl := tt.url + for tries > 0 { + u, e := url.Parse(turl) + if e != nil { + t.Fatal(e) + } + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: u, + } + h, _ := mux.Handler(r) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if rr.Code != 301 { + if rr.Code != tt.code { + t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code) + } + break + } + if !tt.redirOk { + t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url) + break + } + turl = rr.HeaderMap.Get("Location") + tries-- + } + if tries < 0 { + t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url) + } + } +} + +// Tests for http://code.google.com/p/go/issues/detail?id=900 +func TestMuxRedirectLeadingSlashes(t *testing.T) { + paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} + for _, path := range paths { + req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) + if err != nil { + t.Errorf("%s", err) + } + mux := NewServeMux() + resp := httptest.NewRecorder() + + mux.ServeHTTP(resp, req) + + if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected { + t.Errorf("Expected Location header set to %q; got %q", expected, loc) + return + } + + if code, expected := resp.Code, StatusMovedPermanently; code != expected { + t.Errorf("Expected response code of StatusMovedPermanently; got %d", code) + return + } + } +} + +func TestServerTimeouts(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + reqNum := 0 + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + reqNum++ + fmt.Fprintf(res, "req=%d", reqNum) + })) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() + + // Hit the HTTP server successfully. + tr := &Transport{DisableKeepAlives: true} // they interfere with this test + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + r, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("http Get #1: %v", err) + } + got, _ := ioutil.ReadAll(r.Body) + expected := "req=1" + if string(got) != expected { + t.Errorf("Unexpected response for request #1; got %q; expected %q", + string(got), expected) + } + + // Slow client that should timeout. + t1 := time.Now() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + buf := make([]byte, 1) + n, err := conn.Read(buf) + latency := time.Since(t1) + if n != 0 || err != io.EOF { + t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) + } + if latency < 200*time.Millisecond /* fudge from 250 ms above */ { + t.Errorf("got EOF after %s, want >= %s", latency, 200*time.Millisecond) + } + + // Hit the HTTP server successfully again, verifying that the + // previous slow connection didn't run our handler. (that we + // get "req=2", not "req=3") + r, err = Get(ts.URL) + if err != nil { + t.Fatalf("http Get #2: %v", err) + } + got, _ = ioutil.ReadAll(r.Body) + expected = "req=2" + if string(got) != expected { + t.Errorf("Get #2 got %q, want %q", string(got), expected) + } + + if !testing.Short() { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + go io.Copy(ioutil.Discard, conn) + for i := 0; i < 5; i++ { + _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("on write %d: %v", i, err) + } + time.Sleep(ts.Config.ReadTimeout / 2) + } + } +} + +// golang.org/issue/4741 -- setting only a write timeout that triggers +// shouldn't cause a handler to block forever on reads (next HTTP +// request) that will never happen. +func TestOnlyWriteTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + var conn net.Conn + var afterTimeoutErrc = make(chan error, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + buf := make([]byte, 512<<10) + _, err := w.Write(buf) + if err != nil { + t.Errorf("handler Write error: %v", err) + return + } + conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) + _, err = w.Write(buf) + afterTimeoutErrc <- err + })) + ts.Listener = trackLastConnListener{ts.Listener, &conn} + ts.Start() + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + errc := make(chan error) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + errc <- err + return + } + _, err = io.Copy(ioutil.Discard, res.Body) + errc <- err + }() + select { + case err := <-errc: + if err == nil { + t.Errorf("expected an error from Get request") + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Get error") + } + if err := <-afterTimeoutErrc; err == nil { + t.Error("expected write error after timeout") + } +} + +// trackLastConnListener tracks the last net.Conn that was accepted. +type trackLastConnListener struct { + net.Listener + last *net.Conn // destination +} + +func (l trackLastConnListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + *l.last = c + return +} + +// TestIdentityResponse verifies that a handler can unset +func TestIdentityResponse(t *testing.T) { + defer afterTest(t) + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.Header().Set("Content-Length", "3") + rw.Header().Set("Transfer-Encoding", req.FormValue("te")) + switch { + case req.FormValue("overwrite") == "1": + _, err := rw.Write([]byte("foo TOO LONG")) + if err != ErrContentLength { + t.Errorf("expected ErrContentLength; got %v", err) + } + case req.FormValue("underwrite") == "1": + rw.Header().Set("Content-Length", "500") + rw.Write([]byte("too short")) + default: + rw.Write([]byte("foo")) + } + }) + + ts := httptest.NewServer(handler) + defer ts.Close() + + // Note: this relies on the assumption (which is true) that + // Get sends HTTP/1.1 or greater requests. Otherwise the + // server wouldn't have the choice to send back chunked + // responses. + for _, te := range []string{"", "identity"} { + url := ts.URL + "/?te=" + te + res, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + if cl, expected := res.ContentLength, int64(3); cl != expected { + t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl) + } + if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected { + t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl) + } + if tl, expected := len(res.TransferEncoding), 0; tl != expected { + t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)", + url, expected, tl, res.TransferEncoding) + } + res.Body.Close() + } + + // Verify that ErrContentLength is returned + url := ts.URL + "/?overwrite=1" + res, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + res.Body.Close() + + // Verify that the connection is closed when the declared Content-Length + // is larger than what the handler wrote. + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("error writing: %v", err) + } + + // The ReadAll will hang for a failing test, so use a Timer to + // fail explicitly. + goTimeout(t, 2*time.Second, func() { + got, _ := ioutil.ReadAll(conn) + expectedSuffix := "\r\n\r\ntoo short" + if !strings.HasSuffix(string(got), expectedSuffix) { + t.Errorf("Expected output to end with %q; got response body %q", + expectedSuffix, string(got)) + } + }) +} + +func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + defer afterTest(t) + s := httptest.NewServer(h) + defer s.Close() + + conn, err := net.Dial("tcp", s.Listener.Addr().String()) + if err != nil { + t.Fatal("dial error:", err) + } + defer conn.Close() + + _, err = fmt.Fprint(conn, req) + if err != nil { + t.Fatal("print error:", err) + } + + r := bufio.NewReader(conn) + res, err := ReadResponse(r, &Request{Method: "GET"}) + if err != nil { + t.Fatal("ReadResponse error:", err) + } + + didReadAll := make(chan bool, 1) + go func() { + select { + case <-time.After(5 * time.Second): + t.Error("body not closed after 5s") + return + case <-didReadAll: + } + }() + + _, err = ioutil.ReadAll(r) + if err != nil { + t.Fatal("read error:", err) + } + didReadAll <- true + + if !res.Close { + t.Errorf("Response.Close = false; want true") + } +} + +// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. +func TestServeHTTP10Close(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) +} + +// TestClientCanClose verifies that clients can also force a connection to close. +func TestClientCanClose(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + // Nothing. + })) +} + +// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, +// even for HTTP/1.1 requests. +func TestHandlersCanSetConnectionClose11(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + })) +} + +func TestHandlersCanSetConnectionClose10(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + })) +} + +func TestSetsRemoteAddr(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%s", r.RemoteAddr) + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + ip := string(body) + if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") { + t.Fatalf("Expected local addr; got %q", ip) + } +} + +func TestChunkedResponseHeaders(t *testing.T) { + defer afterTest(t) + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + w.(Flusher).Flush() + fmt.Fprintf(w, "I am a chunked response.") + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + defer res.Body.Close() + if g, e := res.ContentLength, int64(-1); g != e { + t.Errorf("expected ContentLength of %d; got %d", e, g) + } + if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) { + t.Errorf("expected TransferEncoding of %v; got %v", e, g) + } + if _, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length") + } +} + +func TestIdentityResponseHeaders(t *testing.T) { + defer afterTest(t) + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Transfer-Encoding", "identity") + w.(Flusher).Flush() + fmt.Fprintf(w, "I am an identity response.") + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + defer res.Body.Close() + + if g, e := res.TransferEncoding, []string(nil); !reflect.DeepEqual(g, e) { + t.Errorf("expected TransferEncoding of %v; got %v", e, g) + } + if _, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length") + } + if !res.Close { + t.Errorf("expected Connection: close; got %v", res.Close) + } +} + +// Test304Responses verifies that 304s don't declare that they're +// chunking in their response headers and aren't allowed to produce +// output. +func Test304Responses(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNotModified) + _, err := w.Write([]byte("illegal body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +// TestHeadResponses verifies that all MIME type sniffing and Content-Length +// counting of GET requests also happens on HEAD requests. +func TestHeadResponses(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("<html>")) + if err != nil { + t.Errorf("ResponseWriter.Write: %v", err) + } + + // Also exercise the ReaderFrom path + _, err = io.Copy(w, strings.NewReader("789a")) + if err != nil { + t.Errorf("Copy(ResponseWriter, ...): %v", err) + } + })) + defer ts.Close() + res, err := Head(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" { + t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct) + } + if v := res.ContentLength; v != 10 { + t.Errorf("Content-Length: %d; want 10", v) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +func TestTLSHandshakeTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + errc := make(chanWriter, 10) // but only expecting 1 + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) + ts.StartTLS() + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + goTimeout(t, 10*time.Second, func() { + var buf [1]byte + n, err := conn.Read(buf[:]) + if err == nil || n != 0 { + t.Errorf("Read = %d, %v; want an error and no bytes", n, err) + } + }) + select { + case v := <-errc: + if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { + t.Errorf("expected a TLS handshake timeout error; got %q", v) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for logged error") + } +} + +func TestTLSServer(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS != nil { + w.Header().Set("X-TLS-Set", "true") + if r.TLS.HandshakeComplete { + w.Header().Set("X-TLS-HandshakeComplete", "true") + } + } + })) + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + defer ts.Close() + + // Connect an idle TCP connection to this server before we run + // our real tests. This idle connection used to block forever + // in the TLS handshake, preventing future connections from + // being accepted. It may prevent future accidental blocking + // in newConn. + idleConn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer idleConn.Close() + goTimeout(t, 10*time.Second, func() { + if !strings.HasPrefix(ts.URL, "https://") { + t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) + return + } + noVerifyTransport := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + client := &Client{Transport: noVerifyTransport} + res, err := client.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if res == nil { + t.Errorf("got nil Response") + return + } + defer res.Body.Close() + if res.Header.Get("X-TLS-Set") != "true" { + t.Errorf("expected X-TLS-Set response header") + return + } + if res.Header.Get("X-TLS-HandshakeComplete") != "true" { + t.Errorf("expected X-TLS-HandshakeComplete header") + } + }) +} + +type serverExpectTest struct { + contentLength int // of request body + chunked bool + expectation string // e.g. "100-continue" + readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) + expectedResponse string // expected substring in first line of http response +} + +func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest { + return serverExpectTest{ + contentLength: contentLength, + expectation: expectation, + readBody: readBody, + expectedResponse: expectedResponse, + } +} + +var serverExpectTests = []serverExpectTest{ + // Normal 100-continues, case-insensitive. + expectTest(100, "100-continue", true, "100 Continue"), + expectTest(100, "100-cOntInUE", true, "100 Continue"), + + // No 100-continue. + expectTest(100, "", true, "200 OK"), + + // 100-continue but requesting client to deny us, + // so it never reads the body. + expectTest(100, "100-continue", false, "401 Unauthorized"), + // Likewise without 100-continue: + expectTest(100, "", false, "401 Unauthorized"), + + // Non-standard expectations are failures + expectTest(0, "a-pony", false, "417 Expectation Failed"), + + // Expect-100 requested but no body (is apparently okay: Issue 7625) + expectTest(0, "100-continue", true, "200 OK"), + // Expect-100 requested but handler doesn't read the body + expectTest(0, "100-continue", false, "401 Unauthorized"), + // Expect-100 continue with no body, but a chunked body. + { + expectation: "100-continue", + readBody: true, + chunked: true, + expectedResponse: "100 Continue", + }, +} + +// Tests that the server responds to the "Expect" request header +// correctly. +func TestServerExpect(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Note using r.FormValue("readbody") because for POST + // requests that would read from r.Body, which we only + // conditionally want to do. + if strings.Contains(r.URL.RawQuery, "readbody=true") { + ioutil.ReadAll(r.Body) + w.Write([]byte("Hi")) + } else { + w.WriteHeader(StatusUnauthorized) + } + })) + defer ts.Close() + + runTest := func(test serverExpectTest) { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + // Only send the body immediately if we're acting like an HTTP client + // that doesn't send 100-continue expectations. + writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue" + + go func() { + contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength) + if test.chunked { + contentLen = "Transfer-Encoding: chunked" + } + _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+ + "Connection: close\r\n"+ + "%s\r\n"+ + "Expect: %s\r\nHost: foo\r\n\r\n", + test.readBody, contentLen, test.expectation) + if err != nil { + t.Errorf("On test %#v, error writing request headers: %v", test, err) + return + } + if writeBody { + var targ io.WriteCloser = struct { + io.Writer + io.Closer + }{ + conn, + ioutil.NopCloser(nil), + } + if test.chunked { + targ = httputil.NewChunkedWriter(conn) + } + body := strings.Repeat("A", test.contentLength) + _, err = fmt.Fprint(targ, body) + if err == nil { + err = targ.Close() + } + if err != nil { + if !test.readBody { + // Server likely already hung up on us. + // See larger comment below. + t.Logf("On test %#v, acceptable error writing request body: %v", test, err) + return + } + t.Errorf("On test %#v, error writing request body: %v", test, err) + } + } + }() + bufr := bufio.NewReader(conn) + line, err := bufr.ReadString('\n') + if err != nil { + if writeBody && !test.readBody { + // This is an acceptable failure due to a possible TCP race: + // We were still writing data and the server hung up on us. A TCP + // implementation may send a RST if our request body data was known + // to be lost, which may trigger our reads to fail. + // See RFC 1122 page 88. + t.Logf("On test %#v, acceptable error from ReadString: %v", test, err) + return + } + t.Fatalf("On test %#v, ReadString: %v", test, err) + } + if !strings.Contains(line, test.expectedResponse) { + t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse) + } + } + + for _, test := range serverExpectTests { + runTest(test) + } +} + +// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server +// should consume client request bodies that a handler didn't read. +func TestServerUnreadRequestBodyLittle(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 100<<10) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n", len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + defer close(done) + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + rw.(Flusher).Flush() + if g, e := conn.readBuf.Len(), 0; g != e { + t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) + } + if c := rw.Header().Get("Connection"); c != "" { + t.Errorf(`Connection header = %q; want ""`, c) + } + })) + <-done +} + +// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server +// should ignore client request bodies that a handler didn't read +// and close the connection. +func TestServerUnreadRequestBodyLarge(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 1<<20) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n", len(body)))) + conn.readBuf.Write([]byte(body)) + conn.closec = make(chan bool, 1) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + rw.(Flusher).Flush() + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + })) + <-conn.closec + + if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") { + t.Errorf("Expected a Connection: close header; got response: %s", res) + } +} + +func TestTimeoutHandler(t *testing.T) { + defer afterTest(t) + sendHi := make(chan bool, 1) + writeErrors := make(chan error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan time.Time, 1) // write to this to force timeouts + ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) + defer ts.Close() + + // Succeed without timing out: + sendHi <- true + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- time.Time{} + res, err = Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "<title>Timeout</title>") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} + +// See issues 8209 and 8414. +func TestTimeoutHandlerRace(t *testing.T) { + defer afterTest(t) + + delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + ms, _ := strconv.Atoi(r.URL.Path[1:]) + if ms == 0 { + ms = 1 + } + for i := 0; i < ms; i++ { + w.Write([]byte("hi")) + time.Sleep(time.Millisecond) + } + }) + + ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) + defer ts.Close() + + var wg sync.WaitGroup + gate := make(chan bool, 10) + n := 50 + if testing.Short() { + n = 10 + gate = make(chan bool, 3) + } + for i := 0; i < n; i++ { + gate <- true + wg.Add(1) + go func() { + defer wg.Done() + defer func() { <-gate }() + res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) + if err == nil { + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + } + }() + } + wg.Wait() +} + +// See issues 8209 and 8414. +func TestTimeoutHandlerRaceHeader(t *testing.T) { + defer afterTest(t) + + delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(204) + }) + + ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, "")) + defer ts.Close() + + var wg sync.WaitGroup + gate := make(chan bool, 50) + n := 500 + if testing.Short() { + n = 10 + } + for i := 0; i < n; i++ { + gate <- true + wg.Add(1) + go func() { + defer wg.Done() + defer func() { <-gate }() + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + return + } + defer res.Body.Close() + io.Copy(ioutil.Discard, res.Body) + }() + } + wg.Wait() +} + +// Verifies we don't path.Clean() on the wrong parts in redirects. +func TestRedirectMunging(t *testing.T) { + req, _ := NewRequest("GET", "http://example.com/", nil) + + resp := httptest.NewRecorder() + Redirect(resp, req, "/foo?next=http://bar.com/", 302) + if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e { + t.Errorf("Location header was %q; want %q", g, e) + } + + resp = httptest.NewRecorder() + Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302) + if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e { + t.Errorf("Location header was %q; want %q", g, e) + } +} + +func TestRedirectBadPath(t *testing.T) { + // This used to crash. It's not valid input (bad path), but it + // shouldn't crash. + rr := httptest.NewRecorder() + req := &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Path: "not-empty-but-no-leading-slash", // bogus + }, + } + Redirect(rr, req, "", 304) + if rr.Code != 304 { + t.Errorf("Code = %d; want 304", rr.Code) + } +} + +// TestZeroLengthPostAndResponse exercises an optimization done by the Transport: +// when there is no body (either because the method doesn't permit a body, or an +// explicit Content-Length of zero is present), then the transport can re-use the +// connection immediately. But when it re-uses the connection, it typically closes +// the previous request's body, which is not optimal for zero-lengthed bodies, +// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. +func TestZeroLengthPostAndResponse(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + all, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("handler ReadAll: %v", err) + } + if len(all) != 0 { + t.Errorf("handler got %d bytes; expected 0", len(all)) + } + rw.Header().Set("Content-Length", "0") + })) + defer ts.Close() + + req, err := NewRequest("POST", ts.URL, strings.NewReader("")) + if err != nil { + t.Fatal(err) + } + req.ContentLength = 0 + + var resp [5]*Response + for i := range resp { + resp[i], err = DefaultClient.Do(req) + if err != nil { + t.Fatalf("client post #%d: %v", i, err) + } + } + + for i := range resp { + all, err := ioutil.ReadAll(resp[i].Body) + if err != nil { + t.Fatalf("req #%d: client ReadAll: %v", i, err) + } + if len(all) != 0 { + t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all)) + } + } +} + +func TestHandlerPanicNil(t *testing.T) { + testHandlerPanic(t, false, nil) +} + +func TestHandlerPanic(t *testing.T) { + testHandlerPanic(t, false, "intentional death for testing") +} + +func TestHandlerPanicWithHijack(t *testing.T) { + testHandlerPanic(t, true, "intentional death for testing") +} + +func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { + defer afterTest(t) + // Unlike the other tests that set the log output to ioutil.Discard + // to quiet the output, this test uses a pipe. The pipe serves three + // purposes: + // + // 1) The log.Print from the http server (generated by the caught + // panic) will go to the pipe instead of stderr, making the + // output quiet. + // + // 2) We read from the pipe to verify that the handler + // actually caught the panic and logged something. + // + // 3) The blocking Read call prevents this TestHandlerPanic + // function from exiting before the HTTP server handler + // finishes crashing. If this text function exited too + // early (and its defer log.SetOutput(os.Stderr) ran), + // then the crash output could spill into the next test. + pr, pw := io.Pipe() + log.SetOutput(pw) + defer log.SetOutput(os.Stderr) + defer pw.Close() + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if withHijack { + rwc, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Logf("unexpected error: %v", err) + } + defer rwc.Close() + } + panic(panicValue) + })) + defer ts.Close() + + // Do a blocking read on the log output pipe so its logging + // doesn't bleed into the next test. But wait only 5 seconds + // for it. + done := make(chan bool, 1) + go func() { + buf := make([]byte, 4<<10) + _, err := pr.Read(buf) + pr.Close() + if err != nil && err != io.EOF { + t.Error(err) + } + done <- true + }() + + _, err := Get(ts.URL) + if err == nil { + t.Logf("expected an error") + } + + if panicValue == nil { + return + } + + select { + case <-done: + return + case <-time.After(5 * time.Second): + t.Fatal("expected server handler to log an error") + } +} + +func TestNoDate(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header()["Date"] = nil + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + _, present := res.Header["Date"] + if present { + t.Fatalf("Expected no Date header; got %v", res.Header["Date"]) + } +} + +func TestStripPrefix(t *testing.T) { + defer afterTest(t) + h := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Path", r.URL.Path) + }) + ts := httptest.NewServer(StripPrefix("/foo", h)) + defer ts.Close() + + res, err := Get(ts.URL + "/foo/bar") + if err != nil { + t.Fatal(err) + } + if g, e := res.Header.Get("X-Path"), "/bar"; g != e { + t.Errorf("test 1: got %s, want %s", g, e) + } + res.Body.Close() + + res, err = Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + if g, e := res.StatusCode, 404; g != e { + t.Errorf("test 2: got status %v, want %v", g, e) + } + res.Body.Close() +} + +func TestRequestLimit(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + t.Fatalf("didn't expect to get request in Handler") + })) + defer ts.Close() + req, _ := NewRequest("GET", ts.URL, nil) + var bytesPerHeader = len("header12345: val12345\r\n") + for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { + req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i)) + } + res, err := DefaultClient.Do(req) + if err != nil { + // Some HTTP clients may fail on this undefined behavior (server replying and + // closing the connection while the request is still being written), but + // we do support it (at least currently), so we expect a response below. + t.Fatalf("Do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 413 { + t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +type countReader struct { + r io.Reader + n *int64 +} + +func (cr countReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + atomic.AddInt64(cr.n, int64(n)) + return +} + +func TestRequestBodyLimit(t *testing.T) { + defer afterTest(t) + const limit = 1 << 20 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + r.Body = MaxBytesReader(w, r.Body, limit) + n, err := io.Copy(ioutil.Discard, r.Body) + if err == nil { + t.Errorf("expected error from io.Copy") + } + if n != limit { + t.Errorf("io.Copy = %d, want %d", n, limit) + } + })) + defer ts.Close() + + nWritten := new(int64) + req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) + + // Send the POST, but don't care it succeeds or not. The + // remote side is going to reply and then close the TCP + // connection, and HTTP doesn't really define if that's + // allowed or not. Some HTTP clients will get the response + // and some (like ours, currently) will complain that the + // request write failed, without reading the response. + // + // But that's okay, since what we're really testing is that + // the remote side hung up on us before we wrote too much. + _, _ = DefaultClient.Do(req) + + if atomic.LoadInt64(nWritten) > limit*100 { + t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", + limit, nWritten) + } +} + +// TestClientWriteShutdown tests that if the client shuts down the write +// side of their TCP connection, the server doesn't send a 400 Bad Request. +func TestClientWriteShutdown(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + err = conn.(*net.TCPConn).CloseWrite() + if err != nil { + t.Fatalf("Dial: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + bs, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + got := string(bs) + if got != "" { + t.Errorf("read %q from server; want nothing", got) + } + }() + select { + case <-donec: + case <-time.After(10 * time.Second): + t.Fatalf("timeout") + } +} + +// Tests that chunked server responses that write 1 byte at a time are +// buffered before chunk headers are added, not after chunk headers. +func TestServerBufferedChunking(t *testing.T) { + conn := new(testConn) + conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) + conn.closec = make(chan bool, 1) + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length + rw.Write([]byte{'x'}) + rw.Write([]byte{'y'}) + rw.Write([]byte{'z'}) + })) + <-conn.closec + if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) { + t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q", + conn.writeBuf.Bytes()) + } +} + +// Tests that the server flushes its response headers out when it's +// ignoring the response body and waits a bit before forcefully +// closing the TCP connection, causing the client to get a RST. +// See http://golang.org/issue/3595 +func TestServerGracefulClose(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, "bye", StatusUnauthorized) + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + const bodySize = 5 << 20 + req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)) + for i := 0; i < bodySize; i++ { + req = append(req, 'x') + } + writeErr := make(chan error) + go func() { + _, err := conn.Write(req) + writeErr <- err + }() + br := bufio.NewReader(conn) + lineNum := 0 + for { + line, err := br.ReadString('\n') + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("ReadLine: %v", err) + } + lineNum++ + if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") { + t.Errorf("Response line = %q; want a 401", line) + } + } + // Wait for write to finish. This is a broken pipe on both + // Darwin and Linux, but checking this isn't the point of + // the test. + <-writeErr +} + +func TestCaseSensitiveMethod(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "get" { + t.Errorf(`Got method %q; want "get"`, r.Method) + } + })) + defer ts.Close() + req, _ := NewRequest("get", ts.URL, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() +} + +// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1 +// request (both keep-alive), when a Handler never writes any +// response, the net/http package adds a "Content-Length: 0" response +// header. +func TestContentLengthZero(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) + defer ts.Close() + + for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version) + if err != nil { + t.Fatalf("error writing: %v", err) + } + req, _ := NewRequest("GET", "/", nil) + res, err := ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("error reading response: %v", err) + } + if te := res.TransferEncoding; len(te) > 0 { + t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te) + } + if cl := res.ContentLength; cl != 0 { + t.Errorf("For version %q, Content-Length = %v; want 0", version, cl) + } + conn.Close() + } +} + +func TestCloseNotifier(t *testing.T) { + defer afterTest(t) + gotReq := make(chan bool, 1) + sawClose := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool) + go func() { + _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + } + ts.Close() +} + +func TestCloseNotifierChanLeak(t *testing.T) { + defer afterTest(t) + req := reqBytes("GET / HTTP/1.0\nHost: golang.org") + for i := 0; i < 20; i++ { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + // Ignore the return value and never read from + // it, testing that we don't leak goroutines + // on the sending side: + _ = rw.(CloseNotifier).CloseNotify() + }) + go Serve(ln, handler) + <-conn.closec + } +} + +func TestOptions(t *testing.T) { + uric := make(chan string, 2) // only expect 1, but leave space for 2 + mux := NewServeMux() + mux.HandleFunc("/", func(w ResponseWriter, r *Request) { + uric <- r.RequestURI + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // An OPTIONS * request should succeed. + _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + br := bufio.NewReader(conn) + res, err := ReadResponse(br, &Request{Method: "OPTIONS"}) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Errorf("Got non-200 response to OPTIONS *: %#v", res) + } + + // A GET * request on a ServeMux should fail. + _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + res, err = ReadResponse(br, &Request{Method: "GET"}) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 400 { + t.Errorf("Got non-400 response to GET *: %#v", res) + } + + res, err = Get(ts.URL + "/second") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if got := <-uric; got != "/second" { + t.Errorf("Handler saw request for %q; want /second", got) + } +} + +// Tests regarding the ordering of Write, WriteHeader, Header, and +// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the +// (*response).header to the wire. In Go 1.1, the actual wire flush is +// delayed, so we could maybe tack on a Content-Length and better +// Content-Type after we see more (or all) of the output. To preserve +// compatibility with Go 1, we need to be careful to track which +// headers were live at the time of WriteHeader, so we write the same +// ones, even if the handler modifies them (~erroneously) after the +// first Write. +func TestHeaderToWire(t *testing.T) { + tests := []struct { + name string + handler func(ResponseWriter, *Request) + check func(output string) error + }{ + { + name: "write without Header", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("no content-length") + } + return nil + }, + }, + { + name: "Header mutation before write", + handler: func(rw ResponseWriter, r *Request) { + h := rw.Header() + h.Set("Content-Type", "some/type") + rw.Write([]byte("hello world")) + h.Set("Too-Late", "bogus") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-type") + } + if strings.Contains(got, "Too-Late") { + return errors.New("don't want too-late header") + } + return nil + }, + }, + { + name: "write then useless Header mutation", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "flush then write", + handler: func(rw ResponseWriter, r *Request) { + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "header then flush", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length") + } + return nil + }, + }, + { + name: "sniff-on-first-write content-type", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + rw.Header().Set("Content-Type", "x/wrong") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/html") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "explicit content-type wins", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "empty handler", + handler: func(rw ResponseWriter, r *Request) { + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("wrong content-length; want text/plain") + } + if !strings.Contains(got, "Content-Length: 0") { + return errors.New("want 0 content-length") + } + return nil + }, + }, + { + name: "only Header, no write", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Some-Header", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "Some-Header") { + return errors.New("didn't get header") + } + return nil + }, + }, + { + name: "WriteHeader call", + handler: func(rw ResponseWriter, r *Request) { + rw.WriteHeader(404) + rw.Header().Set("Too-Late", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "404") { + return errors.New("wrong status") + } + if strings.Contains(got, "Some-Header") { + return errors.New("shouldn't have seen Too-Late") + } + return nil + }, + }, + } + for _, tc := range tests { + ht := newHandlerTest(HandlerFunc(tc.handler)) + got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org") + if err := tc.check(got); err != nil { + t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, got) + } + } +} + +// goTimeout runs f, failing t if f takes more than ns to complete. +func goTimeout(t *testing.T, d time.Duration, f func()) { + ch := make(chan bool, 2) + timer := time.AfterFunc(d, func() { + t.Errorf("Timeout expired after %v", d) + ch <- true + }) + defer timer.Stop() + go func() { + defer func() { ch <- true }() + f() + }() + <-ch +} + +type errorListener struct { + errs []error +} + +func (l *errorListener) Accept() (c net.Conn, err error) { + if len(l.errs) == 0 { + return nil, io.EOF + } + err = l.errs[0] + l.errs = l.errs[1:] + return +} + +func (l *errorListener) Close() error { + return nil +} + +func (l *errorListener) Addr() net.Addr { + return dummyAddr("test-address") +} + +func TestAcceptMaxFds(t *testing.T) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ln := &errorListener{[]error{ + &net.OpError{ + Op: "accept", + Err: syscall.EMFILE, + }}} + err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {}))) + if err != io.EOF { + t.Errorf("got error %v, want EOF", err) + } +} + +func TestWriteAfterHijack(t *testing.T) { + req := reqBytes("GET / HTTP/1.1\nHost: golang.org") + var buf bytes.Buffer + wrotec := make(chan bool, 1) + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &buf, + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + conn, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + go func() { + bufrw.Write([]byte("[hijack-to-bufw]")) + bufrw.Flush() + conn.Write([]byte("[hijack-to-conn]")) + conn.Close() + wrotec <- true + }() + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + <-wrotec + if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w { + t.Errorf("wrote %q; want %q", g, w) + } +} + +func TestDoubleHijack(t *testing.T) { + req := reqBytes("GET / HTTP/1.1\nHost: golang.org") + var buf bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &buf, + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + conn, _, err := rw.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + _, _, err = rw.(Hijacker).Hijack() + if err == nil { + t.Errorf("got err = nil; want err != nil") + } + conn.Close() + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec +} + +// http://code.google.com/p/go/issues/detail?id=5955 +// Note that this does not test the "request too large" +// exit path from the http server. This is intentional; +// not sending Connection: close is just a minor wire +// optimization and is pointless if dealing with a +// badly behaved client. +func TestHTTP10ConnectionHeader(t *testing.T) { + defer afterTest(t) + + mux := NewServeMux() + mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {})) + ts := httptest.NewServer(mux) + defer ts.Close() + + // net/http uses HTTP/1.1 for requests, so write requests manually + tests := []struct { + req string // raw http request + expect []string // expected Connection header(s) + }{ + { + req: "GET / HTTP/1.0\r\n\r\n", + expect: nil, + }, + { + req: "OPTIONS * HTTP/1.0\r\n\r\n", + expect: nil, + }, + { + req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", + expect: []string{"keep-alive"}, + }, + } + + for _, tt := range tests { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal("dial err:", err) + } + + _, err = fmt.Fprint(conn, tt.req) + if err != nil { + t.Fatal("conn write err:", err) + } + + resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"}) + if err != nil { + t.Fatal("ReadResponse err:", err) + } + conn.Close() + resp.Body.Close() + + got := resp.Header["Connection"] + if !reflect.DeepEqual(got, tt.expect) { + t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect) + } + } +} + +// See golang.org/issue/5660 +func TestServerReaderFromOrder(t *testing.T) { + defer afterTest(t) + pr, pw := io.Pipe() + const size = 3 << 20 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path + done := make(chan bool) + go func() { + io.Copy(rw, pr) + close(done) + }() + time.Sleep(25 * time.Millisecond) // give Copy a chance to break things + n, err := io.Copy(ioutil.Discard, req.Body) + if err != nil { + t.Errorf("handler Copy: %v", err) + return + } + if n != size { + t.Errorf("handler Copy = %d; want %d", n, size) + } + pw.Write([]byte("hi")) + pw.Close() + <-done + })) + defer ts.Close() + + req, err := NewRequest("POST", ts.URL, io.LimitReader(neverEnding('a'), size)) + if err != nil { + t.Fatal(err) + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if string(all) != "hi" { + t.Errorf("Body = %q; want hi", all) + } +} + +// Issue 6157, Issue 6685 +func TestCodesPreventingContentTypeAndBody(t *testing.T) { + for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/header" { + w.Header().Set("Content-Length", "123") + } + w.WriteHeader(code) + if r.URL.Path == "/more" { + w.Write([]byte("stuff")) + } + })) + for _, req := range []string{ + "GET / HTTP/1.0", + "GET /header HTTP/1.0", + "GET /more HTTP/1.0", + "GET / HTTP/1.1", + "GET /header HTTP/1.1", + "GET /more HTTP/1.1", + } { + got := ht.rawResponse(req) + wantStatus := fmt.Sprintf("%d %s", code, StatusText(code)) + if !strings.Contains(got, wantStatus) { + t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got) + } else if strings.Contains(got, "Content-Length") { + t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got) + } else if strings.Contains(got, "stuff") { + t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got) + } + } + } +} + +func TestContentTypeOkayOn204(t *testing.T) { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "123") // suppressed + w.Header().Set("Content-Type", "foo/bar") + w.WriteHeader(204) + })) + got := ht.rawResponse("GET / HTTP/1.1") + if !strings.Contains(got, "Content-Type: foo/bar") { + t.Errorf("Response = %q; want Content-Type: foo/bar", got) + } + if strings.Contains(got, "Content-Length: 123") { + t.Errorf("Response = %q; don't want a Content-Length", got) + } +} + +// Issue 6995 +// A server Handler can receive a Request, and then turn around and +// give a copy of that Request.Body out to the Transport (e.g. any +// proxy). So then two people own that Request.Body (both the server +// and the http client), and both think they can close it on failure. +// Therefore, all incoming server requests Bodies need to be thread-safe. +func TestTransportAndServerSharedBodyRace(t *testing.T) { + defer afterTest(t) + + const bodySize = 1 << 20 + + unblockBackend := make(chan bool) + backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + io.CopyN(rw, req.Body, bodySize/2) + <-unblockBackend + })) + defer backend.Close() + + backendRespc := make(chan *Response, 1) + proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.RequestURI == "/foo" { + rw.Write([]byte("bar")) + return + } + req2, _ := NewRequest("POST", backend.URL, req.Body) + req2.ContentLength = bodySize + + bresp, err := DefaultClient.Do(req2) + if err != nil { + t.Errorf("Proxy outbound request: %v", err) + return + } + _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4) + if err != nil { + t.Errorf("Proxy copy error: %v", err) + return + } + backendRespc <- bresp // to close later + + // Try to cause a race: Both the DefaultTransport and the proxy handler's Server + // will try to read/close req.Body (aka req2.Body) + DefaultTransport.(*Transport).CancelRequest(req2) + rw.Write([]byte("OK")) + })) + defer proxy.Close() + + req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize)) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatalf("Original request: %v", err) + } + + // Cleanup, so we don't leak goroutines. + res.Body.Close() + close(unblockBackend) + (<-backendRespc).Body.Close() +} + +// Test that a hanging Request.Body.Read from another goroutine can't +// cause the Handler goroutine's Request.Body.Close to block. +func TestRequestBodyCloseDoesntBlock(t *testing.T) { + t.Skipf("Skipping known issue; see golang.org/issue/7121") + if testing.Short() { + t.Skip("skipping in -short mode") + } + defer afterTest(t) + + readErrCh := make(chan error, 1) + errCh := make(chan error, 2) + + server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + go func(body io.Reader) { + _, err := body.Read(make([]byte, 100)) + readErrCh <- err + }(req.Body) + time.Sleep(500 * time.Millisecond) + })) + defer server.Close() + + closeConn := make(chan bool) + defer close(closeConn) + go func() { + conn, err := net.Dial("tcp", server.Listener.Addr().String()) + if err != nil { + errCh <- err + return + } + defer conn.Close() + _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n")) + if err != nil { + errCh <- err + return + } + // And now just block, making the server block on our + // 100000 bytes of body that will never arrive. + <-closeConn + }() + select { + case err := <-readErrCh: + if err == nil { + t.Error("Read was nil. Expected error.") + } + case err := <-errCh: + t.Error(err) + case <-time.After(5 * time.Second): + t.Error("timeout") + } +} + +func TestResponseWriterWriteStringAllocs(t *testing.T) { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/s" { + io.WriteString(w, "Hello world") + } else { + w.Write([]byte("Hello world")) + } + })) + before := testing.AllocsPerRun(50, func() { ht.rawResponse("GET / HTTP/1.0") }) + after := testing.AllocsPerRun(50, func() { ht.rawResponse("GET /s HTTP/1.0") }) + if int(after) >= int(before) { + t.Errorf("WriteString allocs of %v >= Write allocs of %v", after, before) + } +} + +func TestAppendTime(t *testing.T) { + var b [len(TimeFormat)]byte + t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60)) + res := ExportAppendTime(b[:0], t1) + t2, err := ParseTime(string(res)) + if err != nil { + t.Fatalf("Error parsing time: %s", err) + } + if !t1.Equal(t2) { + t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res)) + } +} + +func TestServerConnState(t *testing.T) { + defer afterTest(t) + handler := map[string]func(w ResponseWriter, r *Request){ + "/": func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello.") + }, + "/close": func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + fmt.Fprintf(w, "Hello.") + }, + "/hijack": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + }, + "/hijack-panic": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + panic("intentional panic") + }, + } + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + handler[r.URL.Path](w, r) + })) + defer ts.Close() + + var mu sync.Mutex // guard stateLog and connID + var stateLog = map[int][]ConnState{} + var connID = map[net.Conn]int{} + + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + mu.Lock() + defer mu.Unlock() + id, ok := connID[c] + if !ok { + id = len(connID) + 1 + connID[c] = id + } + stateLog[id] = append(stateLog[id], state) + } + ts.Start() + + mustGet(t, ts.URL+"/") + mustGet(t, ts.URL+"/close") + + mustGet(t, ts.URL+"/") + mustGet(t, ts.URL+"/", "Connection", "close") + + mustGet(t, ts.URL+"/hijack") + mustGet(t, ts.URL+"/hijack-panic") + + // New->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Active->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Idle->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil { + t.Fatal(err) + } + res, err := ReadResponse(bufio.NewReader(c), nil) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatal(err) + } + c.Close() + } + + want := map[int][]ConnState{ + 1: {StateNew, StateActive, StateIdle, StateActive, StateClosed}, + 2: {StateNew, StateActive, StateIdle, StateActive, StateClosed}, + 3: {StateNew, StateActive, StateHijacked}, + 4: {StateNew, StateActive, StateHijacked}, + 5: {StateNew, StateClosed}, + 6: {StateNew, StateActive, StateClosed}, + 7: {StateNew, StateActive, StateIdle, StateClosed}, + } + logString := func(m map[int][]ConnState) string { + var b bytes.Buffer + for id, l := range m { + fmt.Fprintf(&b, "Conn %d: ", id) + for _, s := range l { + fmt.Fprintf(&b, "%s ", s) + } + b.WriteString("\n") + } + return b.String() + } + + for i := 0; i < 5; i++ { + time.Sleep(time.Duration(i) * 50 * time.Millisecond) + mu.Lock() + match := reflect.DeepEqual(stateLog, want) + mu.Unlock() + if match { + return + } + } + + mu.Lock() + t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want)) + mu.Unlock() +} + +func mustGet(t *testing.T, url string, headers ...string) { + req, err := NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + for len(headers) > 0 { + req.Header.Add(headers[0], headers[1]) + headers = headers[2:] + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Errorf("Error fetching %s: %v", url, err) + return + } + _, err = ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Errorf("Error reading %s: %v", url, err) + } +} + +func TestServerKeepAlivesEnabled(t *testing.T) { + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.SetKeepAlivesEnabled(false) + ts.Start() + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if !res.Close { + t.Errorf("Body.Close == false; want true") + } +} + +// golang.org/issue/7856 +func TestServerEmptyBodyRace(t *testing.T) { + defer afterTest(t) + var n int32 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + atomic.AddInt32(&n, 1) + })) + defer ts.Close() + var wg sync.WaitGroup + const reqs = 20 + for i := 0; i < reqs; i++ { + wg.Add(1) + go func() { + defer wg.Done() + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + return + } + defer res.Body.Close() + _, err = io.Copy(ioutil.Discard, res.Body) + if err != nil { + t.Error(err) + return + } + }() + } + wg.Wait() + if got := atomic.LoadInt32(&n); got != reqs { + t.Errorf("handler ran %d times; want %d", got, reqs) + } +} + +func TestServerConnStateNew(t *testing.T) { + sawNew := false // if the test is buggy, we'll race on this variable. + srv := &Server{ + ConnState: func(c net.Conn, state ConnState) { + if state == StateNew { + sawNew = true // testing that this write isn't racy + } + }, + Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}), // irrelevant + } + srv.Serve(&oneConnListener{ + conn: &rwTestConn{ + Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"), + Writer: ioutil.Discard, + }, + }) + if !sawNew { // testing that this read isn't racy + t.Error("StateNew not seen") + } +} + +type closeWriteTestConn struct { + rwTestConn + didCloseWrite bool +} + +func (c *closeWriteTestConn) CloseWrite() error { + c.didCloseWrite = true + return nil +} + +func TestCloseWrite(t *testing.T) { + var srv Server + var testConn closeWriteTestConn + c, err := ExportServerNewConn(&srv, &testConn) + if err != nil { + t.Fatal(err) + } + ExportCloseWriteAndWait(c) + if !testConn.didCloseWrite { + t.Error("didn't see CloseWrite call") + } +} + +// This verifies that a handler can Flush and then Hijack. +// +// An similar test crashed once during development, but it was only +// testing this tangentially and temporarily until another TODO was +// fixed. +// +// So add an explicit test for this. +func TestServerFlushAndHijack(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, "Hello, ") + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + buf.WriteString("6\r\nworld!\r\n0\r\n\r\n") + if err := buf.Flush(); err != nil { + t.Error(err) + } + if err := conn.Close(); err != nil { + t.Error(err) + } + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "Hello, world!"; string(all) != want { + t.Errorf("Got %q; want %q", all, want) + } +} + +// golang.org/issue/8534 -- the Server shouldn't reuse a connection +// for keep-alive after it's seen any Write error (e.g. a timeout) on +// that net.Conn. +// +// To test, verify we don't timeout or see fewer unique client +// addresses (== unique connections) than requests. +func TestServerKeepAliveAfterWriteError(t *testing.T) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + defer afterTest(t) + const numReq = 3 + addrc := make(chan string, numReq) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + addrc <- r.RemoteAddr + time.Sleep(500 * time.Millisecond) + w.(Flusher).Flush() + })) + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() + + errc := make(chan error, numReq) + go func() { + defer close(errc) + for i := 0; i < numReq; i++ { + res, err := Get(ts.URL) + if res != nil { + res.Body.Close() + } + errc <- err + } + }() + + timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill + defer timeout.Stop() + addrSeen := map[string]bool{} + numOkay := 0 + for { + select { + case v := <-addrc: + addrSeen[v] = true + case err, ok := <-errc: + if !ok { + if len(addrSeen) != numReq { + t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq) + } + if numOkay != 0 { + t.Errorf("got %d successful client requests; want 0", numOkay) + } + return + } + if err == nil { + numOkay++ + } + case <-timeout.C: + t.Fatal("timeout waiting for requests to complete") + } + } +} + +func BenchmarkClientServer(b *testing.B) { + b.ReportAllocs() + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + defer ts.Close() + b.StartTimer() + + for i := 0; i < b.N; i++ { + res, err := Get(ts.URL) + if err != nil { + b.Fatal("Get:", err) + } + all, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + b.Fatal("ReadAll:", err) + } + body := string(all) + if body != "Hello world.\n" { + b.Fatal("Got body:", body) + } + } + + b.StopTimer() +} + +func BenchmarkClientServerParallel4(b *testing.B) { + benchmarkClientServerParallel(b, 4, false) +} + +func BenchmarkClientServerParallel64(b *testing.B) { + benchmarkClientServerParallel(b, 64, false) +} + +func BenchmarkClientServerParallelTLS4(b *testing.B) { + benchmarkClientServerParallel(b, 4, true) +} + +func BenchmarkClientServerParallelTLS64(b *testing.B) { + benchmarkClientServerParallel(b, 64, true) +} + +func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { + b.ReportAllocs() + ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + if useTLS { + ts.StartTLS() + } else { + ts.Start() + } + defer ts.Close() + b.ResetTimer() + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + noVerifyTransport := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + defer noVerifyTransport.CloseIdleConnections() + client := &Client{Transport: noVerifyTransport} + for pb.Next() { + res, err := client.Get(ts.URL) + if err != nil { + b.Logf("Get: %v", err) + continue + } + all, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + b.Logf("ReadAll: %v", err) + continue + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + }) +} + +// A benchmark for profiling the server without the HTTP client code. +// The client code runs in a subprocess. +// +// For use like: +// $ go test -c +// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof +// $ go tool pprof http.test http.prof +// (pprof) web +func BenchmarkServer(b *testing.B) { + b.ReportAllocs() + // Child process mode; + if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" { + n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N")) + if err != nil { + panic(err) + } + for i := 0; i < n; i++ { + res, err := Get(url) + if err != nil { + log.Panicf("Get: %v", err) + } + all, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Panicf("ReadAll: %v", err) + } + body := string(all) + if body != "Hello world.\n" { + log.Panicf("Got body: %q", body) + } + } + os.Exit(0) + return + } + + var res = []byte("Hello world.\n") + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + })) + defer ts.Close() + b.StartTimer() + + cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer") + cmd.Env = append([]string{ + fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N), + fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL), + }, os.Environ()...) + out, err := cmd.CombinedOutput() + if err != nil { + b.Errorf("Test failure: %v, with output: %s", err, out) + } +} + +func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) { + b.ReportAllocs() + req := reqBytes(`GET / HTTP/1.0 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +`) + res := []byte("Hello world!\n") + + conn := &testConn{ + // testConn.Close will not push into the channel + // if it's full. + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := new(oneConnListener) + for i := 0; i < b.N; i++ { + conn.readBuf.Reset() + conn.writeBuf.Reset() + conn.readBuf.Write(req) + ln.conn = conn + Serve(ln, handler) + <-conn.closec + } +} + +// repeatReader reads content count times, then EOFs. +type repeatReader struct { + content []byte + count int + off int +} + +func (r *repeatReader) Read(p []byte) (n int, err error) { + if r.count <= 0 { + return 0, io.EOF + } + n = copy(p, r.content[r.off:]) + r.off += n + if r.off == len(r.content) { + r.count-- + r.off = 0 + } + return +} + +func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) { + b.ReportAllocs() + + req := reqBytes(`GET / HTTP/1.1 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +`) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +// same as above, but representing the most simple possible request +// and handler. Notably: the handler does not call rw.Header(). +func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) { + b.ReportAllocs() + + req := reqBytes(`GET / HTTP/1.1 +Host: golang.org +`) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +const someResponse = "<html>some response</html>" + +// A Response that's just no bigger than 2KB, the buffer-before-chunking threshold. +var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse)) + +// Both Content-Type and Content-Length set. Should be no buffering. +func BenchmarkServerHandlerTypeLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// A Content-Type is set, but no length. No sniffing, but will count the Content-Length. +func BenchmarkServerHandlerNoLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Write(response) + })) +} + +// A Content-Length is set, but the Content-Type will be sniffed. +func BenchmarkServerHandlerNoType(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// Neither a Content-Type or Content-Length, so sniffed and counted. +func BenchmarkServerHandlerNoHeader(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write(response) + })) +} + +func benchmarkHandler(b *testing.B, h Handler) { + b.ReportAllocs() + req := reqBytes(`GET / HTTP/1.1 +Host: golang.org +`) + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + h.ServeHTTP(rw, r) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +func BenchmarkServerHijack(b *testing.B) { + b.ReportAllocs() + req := reqBytes(`GET / HTTP/1.1 +Host: golang.org +`) + h := HandlerFunc(func(w ResponseWriter, r *Request) { + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + panic(err) + } + conn.Close() + }) + conn := &rwTestConn{ + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + for i := 0; i < b.N; i++ { + conn.Reader = bytes.NewReader(req) + ln.conn = conn + Serve(ln, h) + <-conn.closec + } +} diff --git a/src/net/http/server.go b/src/net/http/server.go new file mode 100644 index 000000000..008d5aa7a --- /dev/null +++ b/src/net/http/server.go @@ -0,0 +1,2096 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP server. See RFC 2616. + +package http + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/url" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Errors introduced by the HTTP server. +var ( + ErrWriteAfterFlush = errors.New("Conn.Write called after Flush") + ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body") + ErrHijacked = errors.New("Conn has been hijacked") + ErrContentLength = errors.New("Conn.Write wrote more than the declared Content-Length") +) + +// Objects implementing the Handler interface can be +// registered to serve a particular path or subtree +// in the HTTP server. +// +// ServeHTTP should write reply headers and data to the ResponseWriter +// and then return. Returning signals that the request is finished +// and that the HTTP server can move on to the next request on +// the connection. +// +// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes +// that the effect of the panic was isolated to the active request. +// It recovers the panic, logs a stack trace to the server error log, +// and hangs up the connection. +// +type Handler interface { + ServeHTTP(ResponseWriter, *Request) +} + +// A ResponseWriter interface is used by an HTTP handler to +// construct an HTTP response. +type ResponseWriter interface { + // Header returns the header map that will be sent by WriteHeader. + // Changing the header after a call to WriteHeader (or Write) has + // no effect. + Header() Header + + // Write writes the data to the connection as part of an HTTP reply. + // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK) + // before writing the data. If the Header does not contain a + // Content-Type line, Write adds a Content-Type set to the result of passing + // the initial 512 bytes of written data to DetectContentType. + Write([]byte) (int, error) + + // WriteHeader sends an HTTP response header with status code. + // If WriteHeader is not called explicitly, the first call to Write + // will trigger an implicit WriteHeader(http.StatusOK). + // Thus explicit calls to WriteHeader are mainly used to + // send error codes. + WriteHeader(int) +} + +// The Flusher interface is implemented by ResponseWriters that allow +// an HTTP handler to flush buffered data to the client. +// +// Note that even for ResponseWriters that support Flush, +// if the client is connected through an HTTP proxy, +// the buffered data may not reach the client until the response +// completes. +type Flusher interface { + // Flush sends any buffered data to the client. + Flush() +} + +// The Hijacker interface is implemented by ResponseWriters that allow +// an HTTP handler to take over the connection. +type Hijacker interface { + // Hijack lets the caller take over the connection. + // After a call to Hijack(), the HTTP server library + // will not do anything else with the connection. + // It becomes the caller's responsibility to manage + // and close the connection. + Hijack() (net.Conn, *bufio.ReadWriter, error) +} + +// The CloseNotifier interface is implemented by ResponseWriters which +// allow detecting when the underlying connection has gone away. +// +// This mechanism can be used to cancel long operations on the server +// if the client has disconnected before the response is ready. +type CloseNotifier interface { + // CloseNotify returns a channel that receives a single value + // when the client connection has gone away. + CloseNotify() <-chan bool +} + +// A conn represents the server side of an HTTP connection. +type conn struct { + remoteAddr string // network address of remote side + server *Server // the Server on which the connection arrived + rwc net.Conn // i/o connection + w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack + werr error // any errors writing to w + sr liveSwitchReader // where the LimitReader reads from; usually the rwc + lr *io.LimitedReader // io.LimitReader(sr) + buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc + tlsState *tls.ConnectionState // or nil when not using TLS + + mu sync.Mutex // guards the following + clientGone bool // if client has disconnected mid-request + closeNotifyc chan bool // made lazily + hijackedv bool // connection has been hijacked by handler +} + +func (c *conn) hijacked() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.hijackedv +} + +func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.hijackedv { + return nil, nil, ErrHijacked + } + if c.closeNotifyc != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") + } + c.hijackedv = true + rwc = c.rwc + buf = c.buf + c.rwc = nil + c.buf = nil + c.setState(rwc, StateHijacked) + return +} + +func (c *conn) closeNotify() <-chan bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc == nil { + c.closeNotifyc = make(chan bool, 1) + if c.hijackedv { + // to obey the function signature, even though + // it'll never receive a value. + return c.closeNotifyc + } + pr, pw := io.Pipe() + + readSource := c.sr.r + c.sr.Lock() + c.sr.r = pr + c.sr.Unlock() + go func() { + _, err := io.Copy(pw, readSource) + if err == nil { + err = io.EOF + } + pw.CloseWithError(err) + c.noteClientGone() + }() + } + return c.closeNotifyc +} + +func (c *conn) noteClientGone() { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc != nil && !c.clientGone { + c.closeNotifyc <- true + } + c.clientGone = true +} + +// A switchReader can have its Reader changed at runtime. +// It's not safe for concurrent Reads and switches. +type switchReader struct { + io.Reader +} + +// A switchWriter can have its Writer changed at runtime. +// It's not safe for concurrent Writes and switches. +type switchWriter struct { + io.Writer +} + +// A liveSwitchReader is a switchReader that's safe for concurrent +// reads and switches, if its mutex is held. +type liveSwitchReader struct { + sync.Mutex + r io.Reader +} + +func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { + sr.Lock() + r := sr.r + sr.Unlock() + return r.Read(p) +} + +// This should be >= 512 bytes for DetectContentType, +// but otherwise it's somewhat arbitrary. +const bufferBeforeChunkingSize = 2048 + +// chunkWriter writes to a response's conn buffer, and is the writer +// wrapped by the response.bufw buffered writer. +// +// chunkWriter also is responsible for finalizing the Header, including +// conditionally setting the Content-Type and setting a Content-Length +// in cases where the handler's final output is smaller than the buffer +// size. It also conditionally adds chunk headers, when in chunking mode. +// +// See the comment above (*response).Write for the entire write flow. +type chunkWriter struct { + res *response + + // header is either nil or a deep clone of res.handlerHeader + // at the time of res.WriteHeader, if res.WriteHeader is + // called and extra buffering is being done to calculate + // Content-Type and/or Content-Length. + header Header + + // wroteHeader tells whether the header's been written to "the + // wire" (or rather: w.conn.buf). this is unlike + // (*response).wroteHeader, which tells only whether it was + // logically written. + wroteHeader bool + + // set by the writeHeader method: + chunking bool // using chunked transfer encoding for reply body +} + +var ( + crlf = []byte("\r\n") + colonSpace = []byte(": ") +) + +func (cw *chunkWriter) Write(p []byte) (n int, err error) { + if !cw.wroteHeader { + cw.writeHeader(p) + } + if cw.res.req.Method == "HEAD" { + // Eat writes. + return len(p), nil + } + if cw.chunking { + _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) + if err != nil { + cw.res.conn.rwc.Close() + return + } + } + n, err = cw.res.conn.buf.Write(p) + if cw.chunking && err == nil { + _, err = cw.res.conn.buf.Write(crlf) + } + if err != nil { + cw.res.conn.rwc.Close() + } + return +} + +func (cw *chunkWriter) flush() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + cw.res.conn.buf.Flush() +} + +func (cw *chunkWriter) close() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + if cw.chunking { + // zero EOF chunk, trailer key/value pairs (currently + // unsupported in Go's server), followed by a blank + // line. + cw.res.conn.buf.WriteString("0\r\n\r\n") + } +} + +// A response represents the server side of an HTTP response. +type response struct { + conn *conn + req *Request // request for this response + wroteHeader bool // reply header has been (logically) written + wroteContinue bool // 100 Continue response was written + + w *bufio.Writer // buffers output in chunks to chunkWriter + cw chunkWriter + sw *switchWriter // of the bufio.Writer, for return to putBufioWriter + + // handlerHeader is the Header that Handlers get access to, + // which may be retained and mutated even after WriteHeader. + // handlerHeader is copied into cw.header at WriteHeader + // time, and privately mutated thereafter. + handlerHeader Header + calledHeader bool // handler accessed handlerHeader via Header + + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader + + // close connection after this reply. set on request and + // updated after response from handler if there's a + // "Connection: keep-alive" response header and a + // Content-Length. + closeAfterReply bool + + // requestBodyLimitHit is set by requestTooLarge when + // maxBytesReader hits its max size. It is checked in + // WriteHeader, to make sure we don't consume the + // remaining request body to try to advance to the next HTTP + // request. Instead, when this is set, we stop reading + // subsequent requests on this connection and stop reading + // input from it. + requestBodyLimitHit bool + + handlerDone bool // set true when the handler exits + + // Buffers for Date and Content-Length + dateBuf [len(TimeFormat)]byte + clenBuf [10]byte +} + +// requestTooLarge is called by maxBytesReader when too much input has +// been read from the client. +func (w *response) requestTooLarge() { + w.closeAfterReply = true + w.requestBodyLimitHit = true + if !w.wroteHeader { + w.Header().Set("Connection", "close") + } +} + +// needsSniff reports whether a Content-Type still needs to be sniffed. +func (w *response) needsSniff() bool { + _, haveType := w.handlerHeader["Content-Type"] + return !w.cw.wroteHeader && !haveType && w.written < sniffLen +} + +// writerOnly hides an io.Writer value's optional ReadFrom method +// from io.Copy. +type writerOnly struct { + io.Writer +} + +func srcIsRegularFile(src io.Reader) (isRegular bool, err error) { + switch v := src.(type) { + case *os.File: + fi, err := v.Stat() + if err != nil { + return false, err + } + return fi.Mode().IsRegular(), nil + case *io.LimitedReader: + return srcIsRegularFile(v.R) + default: + return + } +} + +// ReadFrom is here to optimize copying from an *os.File regular file +// to a *net.TCPConn with sendfile. +func (w *response) ReadFrom(src io.Reader) (n int64, err error) { + // Our underlying w.conn.rwc is usually a *TCPConn (with its + // own ReadFrom method). If not, or if our src isn't a regular + // file, just fall back to the normal copy method. + rf, ok := w.conn.rwc.(io.ReaderFrom) + regFile, err := srcIsRegularFile(src) + if err != nil { + return 0, err + } + if !ok || !regFile { + return io.Copy(writerOnly{w}, src) + } + + // sendfile path: + + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + + if w.needsSniff() { + n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen)) + n += n0 + if err != nil { + return n, err + } + } + + w.w.Flush() // get rid of any previous writes + w.cw.flush() // make sure Header is written; flush data to rwc + + // Now that cw has been flushed, its chunking field is guaranteed initialized. + if !w.cw.chunking && w.bodyAllowed() { + n0, err := rf.ReadFrom(src) + n += n0 + w.written += n0 + return n, err + } + + n0, err := io.Copy(writerOnly{w}, src) + n += n0 + return n, err +} + +// noLimit is an effective infinite upper bound for io.LimitedReader +const noLimit int64 = (1 << 63) - 1 + +// debugServerConnections controls whether all server connections are wrapped +// with a verbose logging wrapper. +const debugServerConnections = false + +// Create new connection from rwc. +func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { + c = new(conn) + c.remoteAddr = rwc.RemoteAddr().String() + c.server = srv + c.rwc = rwc + c.w = rwc + if debugServerConnections { + c.rwc = newLoggingConn("server", c.rwc) + } + c.sr = liveSwitchReader{r: c.rwc} + c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) + br := newBufioReader(c.lr) + bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + c.buf = bufio.NewReadWriter(br, bw) + return c, nil +} + +var ( + bufioReaderPool sync.Pool + bufioWriter2kPool sync.Pool + bufioWriter4kPool sync.Pool +) + +func bufioWriterPool(size int) *sync.Pool { + switch size { + case 2 << 10: + return &bufioWriter2kPool + case 4 << 10: + return &bufioWriter4kPool + } + return nil +} + +func newBufioReader(r io.Reader) *bufio.Reader { + if v := bufioReaderPool.Get(); v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br + } + return bufio.NewReader(r) +} + +func putBufioReader(br *bufio.Reader) { + br.Reset(nil) + bufioReaderPool.Put(br) +} + +func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { + pool := bufioWriterPool(size) + if pool != nil { + if v := pool.Get(); v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } + } + return bufio.NewWriterSize(w, size) +} + +func putBufioWriter(bw *bufio.Writer) { + bw.Reset(nil) + if pool := bufioWriterPool(bw.Available()); pool != nil { + pool.Put(bw) + } +} + +// DefaultMaxHeaderBytes is the maximum permitted size of the headers +// in an HTTP request. +// This can be overridden by setting Server.MaxHeaderBytes. +const DefaultMaxHeaderBytes = 1 << 20 // 1 MB + +func (srv *Server) maxHeaderBytes() int { + if srv.MaxHeaderBytes > 0 { + return srv.MaxHeaderBytes + } + return DefaultMaxHeaderBytes +} + +func (srv *Server) initialLimitedReaderSize() int64 { + return int64(srv.maxHeaderBytes()) + 4096 // bufio slop +} + +// wrapper around io.ReaderCloser which on first read, sends an +// HTTP/1.1 100 Continue header +type expectContinueReader struct { + resp *response + readCloser io.ReadCloser + closed bool +} + +func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { + if ecr.closed { + return 0, ErrBodyReadAfterClose + } + if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { + ecr.resp.wroteContinue = true + ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + ecr.resp.conn.buf.Flush() + } + return ecr.readCloser.Read(p) +} + +func (ecr *expectContinueReader) Close() error { + ecr.closed = true + return ecr.readCloser.Close() +} + +// TimeFormat is the time format to use with +// time.Parse and time.Time.Format when parsing +// or generating times in HTTP headers. +// It is like time.RFC1123 but hard codes GMT as the time zone. +const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" + +// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat)) +func appendTime(b []byte, t time.Time) []byte { + const days = "SunMonTueWedThuFriSat" + const months = "JanFebMarAprMayJunJulAugSepOctNovDec" + + t = t.UTC() + yy, mm, dd := t.Date() + hh, mn, ss := t.Clock() + day := days[3*t.Weekday():] + mon := months[3*(mm-1):] + + return append(b, + day[0], day[1], day[2], ',', ' ', + byte('0'+dd/10), byte('0'+dd%10), ' ', + mon[0], mon[1], mon[2], ' ', + byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ', + byte('0'+hh/10), byte('0'+hh%10), ':', + byte('0'+mn/10), byte('0'+mn%10), ':', + byte('0'+ss/10), byte('0'+ss%10), ' ', + 'G', 'M', 'T') +} + +var errTooLarge = errors.New("http: request too large") + +// Read next request from connection. +func (c *conn) readRequest() (w *response, err error) { + if c.hijacked() { + return nil, ErrHijacked + } + + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + + c.lr.N = c.server.initialLimitedReaderSize() + var req *Request + if req, err = ReadRequest(c.buf.Reader); err != nil { + if c.lr.N == 0 { + return nil, errTooLarge + } + return nil, err + } + c.lr.N = noLimit + + req.RemoteAddr = c.remoteAddr + req.TLS = c.tlsState + + w = &response{ + conn: c, + req: req, + handlerHeader: make(Header), + contentLength: -1, + } + w.cw.res = w + w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) + return w, nil +} + +func (w *response) Header() Header { + if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { + // Accessing the header between logically writing it + // and physically writing it means we need to allocate + // a clone to snapshot the logically written state. + w.cw.header = w.handlerHeader.clone() + } + w.calledHeader = true + return w.handlerHeader +} + +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 + +func (w *response) WriteHeader(code int) { + if w.conn.hijacked() { + w.conn.server.logf("http: response.WriteHeader on hijacked connection") + return + } + if w.wroteHeader { + w.conn.server.logf("http: multiple response.WriteHeader calls") + return + } + w.wroteHeader = true + w.status = code + + if w.calledHeader && w.cw.header == nil { + w.cw.header = w.handlerHeader.clone() + } + + if cl := w.handlerHeader.get("Content-Length"); cl != "" { + v, err := strconv.ParseInt(cl, 10, 64) + if err == nil && v >= 0 { + w.contentLength = v + } else { + w.conn.server.logf("http: invalid Content-Length of %q", cl) + w.handlerHeader.Del("Content-Length") + } + } +} + +// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader. +// This type is used to avoid extra allocations from cloning and/or populating +// the response Header map and all its 1-element slices. +type extraHeader struct { + contentType string + connection string + transferEncoding string + date []byte // written if not nil + contentLength []byte // written if not nil +} + +// Sorted the same as extraHeader.Write's loop. +var extraHeaderKeys = [][]byte{ + []byte("Content-Type"), + []byte("Connection"), + []byte("Transfer-Encoding"), +} + +var ( + headerContentLength = []byte("Content-Length: ") + headerDate = []byte("Date: ") +) + +// Write writes the headers described in h to w. +// +// This method has a value receiver, despite the somewhat large size +// of h, because it prevents an allocation. The escape analysis isn't +// smart enough to realize this function doesn't mutate h. +func (h extraHeader) Write(w *bufio.Writer) { + if h.date != nil { + w.Write(headerDate) + w.Write(h.date) + w.Write(crlf) + } + if h.contentLength != nil { + w.Write(headerContentLength) + w.Write(h.contentLength) + w.Write(crlf) + } + for i, v := range []string{h.contentType, h.connection, h.transferEncoding} { + if v != "" { + w.Write(extraHeaderKeys[i]) + w.Write(colonSpace) + w.WriteString(v) + w.Write(crlf) + } + } +} + +// writeHeader finalizes the header sent to the client and writes it +// to cw.res.conn.buf. +// +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. It's also used to set the Content-Length, if the +// total body size was small and the handler has already finished +// running. +func (cw *chunkWriter) writeHeader(p []byte) { + if cw.wroteHeader { + return + } + cw.wroteHeader = true + + w := cw.res + keepAlivesEnabled := w.conn.server.doKeepAlives() + isHEAD := w.req.Method == "HEAD" + + // header is written out to w.conn.buf below. Depending on the + // state of the handler, we either own the map or not. If we + // don't own it, the exclude map is created lazily for + // WriteSubset to remove headers. The setHeader struct holds + // headers we need to add. + header := cw.header + owned := header != nil + if !owned { + header = w.handlerHeader + } + var excludeHeader map[string]bool + delHeader := func(key string) { + if owned { + header.Del(key) + return + } + if _, ok := header[key]; !ok { + return + } + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[key] = true + } + var setHeader extraHeader + + // If the handler is done but never sent a Content-Length + // response header and this is our first (and last) write, set + // it, even to zero. This helps HTTP/1.0 clients keep their + // "keep-alive" connections alive. + // Exceptions: 304/204/1xx responses never get Content-Length, and if + // it was a HEAD request, we don't know the difference between + // 0 actual bytes and 0 bytes because the handler noticed it + // was a HEAD request and chose not to write anything. So for + // HEAD, the handler should either write the Content-Length or + // write non-zero bytes. If it's actually 0 bytes and the + // handler never looked at the Request.Method, we just don't + // send a Content-Length header. + if w.handlerDone && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + w.contentLength = int64(len(p)) + setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) + } + + // If this was an HTTP/1.0 request with keep-alive and we sent a + // Content-Length back, we can make this a keep-alive response ... + if w.req.wantsHttp10KeepAlive() && keepAlivesEnabled { + sentLength := header.get("Content-Length") != "" + if sentLength && header.get("Connection") == "keep-alive" { + w.closeAfterReply = false + } + } + + // Check for a explicit (and valid) Content-Length header. + hasCL := w.contentLength != -1 + + if w.req.wantsHttp10KeepAlive() && (isHEAD || hasCL) { + _, connectionHeaderSet := header["Connection"] + if !connectionHeaderSet { + setHeader.connection = "keep-alive" + } + } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { + w.closeAfterReply = true + } + + if header.get("Connection") == "close" || !keepAlivesEnabled { + w.closeAfterReply = true + } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. But we + // don't want to do an unbounded amount of reading here for + // DoS reasons, so we only try up to a threshold. + if w.req.ContentLength != 0 && !w.closeAfterReply { + ecr, isExpecter := w.req.Body.(*expectContinueReader) + if !isExpecter || ecr.resp.wroteContinue { + n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) + if n >= maxPostHandlerReadBytes { + w.requestTooLarge() + delHeader("Connection") + setHeader.connection = "close" + } else { + w.req.Body.Close() + } + } + } + + code := w.status + if bodyAllowedForStatus(code) { + // If no content type, apply sniffing algorithm to body. + _, haveType := header["Content-Type"] + if !haveType { + setHeader.contentType = DetectContentType(p) + } + } else { + for _, k := range suppressedHeaders(code) { + delHeader(k) + } + } + + if _, ok := header["Date"]; !ok { + setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now()) + } + + te := header.get("Transfer-Encoding") + hasTE := te != "" + if hasCL && hasTE && te != "identity" { + // TODO: return an error if WriteHeader gets a return parameter + // For now just ignore the Content-Length. + w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + te, w.contentLength) + delHeader("Content-Length") + hasCL = false + } + + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) { + // do nothing + } else if code == StatusNoContent { + delHeader("Transfer-Encoding") + } else if hasCL { + delHeader("Transfer-Encoding") + } else if w.req.ProtoAtLeast(1, 1) { + // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no + // content-length has been provided. The connection must be closed after the + // reply is written, and no chunking is to be done. This is the setup + // recommended in the Server-Sent Events candidate recommendation 11, + // section 8. + if hasTE && te == "identity" { + cw.chunking = false + w.closeAfterReply = true + } else { + // HTTP/1.1 or greater: use chunked transfer encoding + // to avoid closing the connection at EOF. + cw.chunking = true + setHeader.transferEncoding = "chunked" + } + } else { + // HTTP version < 1.1: cannot do chunked transfer + // encoding and we don't know the Content-Length so + // signal EOF by closing connection. + w.closeAfterReply = true + delHeader("Transfer-Encoding") // in case already set + } + + // Cannot use Content-Length with non-identity Transfer-Encoding. + if cw.chunking { + delHeader("Content-Length") + } + if !w.req.ProtoAtLeast(1, 0) { + return + } + + if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) { + delHeader("Connection") + if w.req.ProtoAtLeast(1, 1) { + setHeader.connection = "close" + } + } + + w.conn.buf.WriteString(statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.buf, excludeHeader) + setHeader.Write(w.conn.buf.Writer) + w.conn.buf.Write(crlf) +} + +// statusLines is a cache of Status-Line strings, keyed by code (for +// HTTP/1.1) or negative code (for HTTP/1.0). This is faster than a +// map keyed by struct of two fields. This map's max size is bounded +// by 2*len(statusText), two protocol types for each known official +// status code in the statusText map. +var ( + statusMu sync.RWMutex + statusLines = make(map[int]string) +) + +// statusLine returns a response Status-Line (RFC 2616 Section 6.1) +// for the given request and response status code. +func statusLine(req *Request, code int) string { + // Fast path: + key := code + proto11 := req.ProtoAtLeast(1, 1) + if !proto11 { + key = -key + } + statusMu.RLock() + line, ok := statusLines[key] + statusMu.RUnlock() + if ok { + return line + } + + // Slow path: + proto := "HTTP/1.0" + if proto11 { + proto = "HTTP/1.1" + } + codestring := strconv.Itoa(code) + text, ok := statusText[code] + if !ok { + text = "status code " + codestring + } + line = proto + " " + codestring + " " + text + "\r\n" + if ok { + statusMu.Lock() + defer statusMu.Unlock() + statusLines[key] = line + } + return line +} + +// bodyAllowed returns true if a Write is allowed for this response type. +// It's illegal to call this before the header has been flushed. +func (w *response) bodyAllowed() bool { + if !w.wroteHeader { + panic("") + } + return bodyAllowedForStatus(w.status) +} + +// The Life Of A Write is like this: +// +// Handler starts. No header has been sent. The handler can either +// write a header, or just start writing. Writing before sending a header +// sends an implicitly empty 200 OK header. +// +// If the handler didn't declare a Content-Length up front, we either +// go into chunking mode or, if the handler finishes running before +// the chunking buffer size, we compute a Content-Length and send that +// in the header instead. +// +// Likewise, if the handler didn't set a Content-Type, we sniff that +// from the initial chunk of output. +// +// The Writers are wired together like: +// +// 1. *response (the ResponseWriter) -> +// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes +// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type) +// and which writes the chunk headers, if needed. +// 4. conn.buf, a bufio.Writer of default (4kB) bytes, writing to -> +// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write +// and populates c.werr with it if so. but otherwise writes to: +// 6. the rwc, the net.Conn. +// +// TODO(bradfitz): short-circuit some of the buffering when the +// initial header contains both a Content-Type and Content-Length. +// Also short-circuit in (1) when the header's been sent and not in +// chunking mode, writing directly to (4) instead, if (2) has no +// buffered data. More generally, we could short-circuit from (1) to +// (3) even in chunking mode if the write size from (1) is over some +// threshold and nothing is in (2). The answer might be mostly making +// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal +// with this instead. +func (w *response) Write(data []byte) (n int, err error) { + return w.write(len(data), data, "") +} + +func (w *response) WriteString(data string) (n int, err error) { + return w.write(len(data), nil, data) +} + +// either dataB or dataS is non-zero. +func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { + if w.conn.hijacked() { + w.conn.server.logf("http: response.Write on hijacked connection") + return 0, ErrHijacked + } + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + if lenData == 0 { + return 0, nil + } + if !w.bodyAllowed() { + return 0, ErrBodyNotAllowed + } + + w.written += int64(lenData) // ignoring errors, for errorKludge + if w.contentLength != -1 && w.written > w.contentLength { + return 0, ErrContentLength + } + if dataB != nil { + return w.w.Write(dataB) + } else { + return w.w.WriteString(dataS) + } +} + +func (w *response) finishRequest() { + w.handlerDone = true + + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + + w.w.Flush() + putBufioWriter(w.w) + w.cw.close() + w.conn.buf.Flush() + + // Close the body (regardless of w.closeAfterReply) so we can + // re-use its bufio.Reader later safely. + w.req.Body.Close() + + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } + + if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { + // Did not write enough. Avoid getting out of sync. + w.closeAfterReply = true + } + + // There was some error writing to the underlying connection + // during the request, so don't re-use this conn. + if w.conn.werr != nil { + w.closeAfterReply = true + } +} + +func (w *response) Flush() { + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + w.w.Flush() + w.cw.flush() +} + +func (c *conn) finalFlush() { + if c.buf != nil { + c.buf.Flush() + + // Steal the bufio.Reader (~4KB worth of memory) and its associated + // reader for a future connection. + putBufioReader(c.buf.Reader) + + // Steal the bufio.Writer (~4KB worth of memory) and its associated + // writer for a future connection. + putBufioWriter(c.buf.Writer) + + c.buf = nil + } +} + +// Close the connection. +func (c *conn) close() { + c.finalFlush() + if c.rwc != nil { + c.rwc.Close() + c.rwc = nil + } +} + +// rstAvoidanceDelay is the amount of time we sleep after closing the +// write side of a TCP connection before closing the entire socket. +// By sleeping, we increase the chances that the client sees our FIN +// and processes its final data before they process the subsequent RST +// from closing a connection with known unread data. +// This RST seems to occur mostly on BSD systems. (And Windows?) +// This timeout is somewhat arbitrary (~latency around the planet). +const rstAvoidanceDelay = 500 * time.Millisecond + +type closeWriter interface { + CloseWrite() error +} + +var _ closeWriter = (*net.TCPConn)(nil) + +// closeWrite flushes any outstanding data and sends a FIN packet (if +// client is connected via TCP), signalling that we're done. We then +// pause for a bit, hoping the client processes it before any +// subsequent RST. +// +// See http://golang.org/issue/3595 +func (c *conn) closeWriteAndWait() { + c.finalFlush() + if tcp, ok := c.rwc.(closeWriter); ok { + tcp.CloseWrite() + } + time.Sleep(rstAvoidanceDelay) +} + +// validNPN reports whether the proto is not a blacklisted Next +// Protocol Negotiation protocol. Empty and built-in protocol types +// are blacklisted and can't be overridden with alternate +// implementations. +func validNPN(proto string) bool { + switch proto { + case "", "http/1.1", "http/1.0": + return false + } + return true +} + +func (c *conn) setState(nc net.Conn, state ConnState) { + if hook := c.server.ConnState; hook != nil { + hook(nc, state) + } +} + +// Serve a new connection. +func (c *conn) serve() { + origConn := c.rwc // copy it before it's set nil on Close or Hijack + defer func() { + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) + } + if !c.hijacked() { + c.close() + c.setState(origConn, StateClosed) + } + }() + + if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + } + if err := tlsConn.Handshake(); err != nil { + c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) + return + } + c.tlsState = new(tls.ConnectionState) + *c.tlsState = tlsConn.ConnectionState() + if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) { + if fn := c.server.TLSNextProto[proto]; fn != nil { + h := initNPNRequest{tlsConn, serverHandler{c.server}} + fn(c.server, tlsConn, h) + } + return + } + } + + for { + w, err := c.readRequest() + if c.lr.N != c.server.initialLimitedReaderSize() { + // If we read any bytes off the wire, we're active. + c.setState(c.rwc, StateActive) + } + if err != nil { + if err == errTooLarge { + // Their HTTP client may or may not be + // able to read this if we're + // responding to them and hanging up + // while they're still writing their + // request. Undefined behavior. + io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n") + c.closeWriteAndWait() + break + } else if err == io.EOF { + break // Don't reply + } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + break // Don't reply + } + io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n") + break + } + + // Expect 100 Continue support + req := w.req + if req.expectsContinue() { + if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + } + req.Header.Del("Expect") + } else if req.Header.get("Expect") != "" { + w.sendExpectationFailed() + break + } + + // HTTP cannot have multiple simultaneous active requests.[*] + // Until the server replies to this request, it can't read another, + // so we might as well run the handler in this goroutine. + // [*] Not strictly true: HTTP pipelining. We could let them all process + // in parallel even if their responses need to be serialized. + serverHandler{c.server}.ServeHTTP(w, w.req) + if c.hijacked() { + return + } + w.finishRequest() + if w.closeAfterReply { + if w.requestBodyLimitHit { + c.closeWriteAndWait() + } + break + } + c.setState(c.rwc, StateIdle) + } +} + +func (w *response) sendExpectationFailed() { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + w.finishRequest() +} + +// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter +// and a Hijacker. +func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + if w.wroteHeader { + w.cw.flush() + } + // Release the bufioWriter that writes to the chunk writer, it is not + // used after a connection has been hijacked. + rwc, buf, err = w.conn.hijack() + if err == nil { + putBufioWriter(w.w) + w.w = nil + } + return rwc, buf, err +} + +func (w *response) CloseNotify() <-chan bool { + return w.conn.closeNotify() +} + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as HTTP handlers. If f is a function +// with the appropriate signature, HandlerFunc(f) is a +// Handler object that calls f. +type HandlerFunc func(ResponseWriter, *Request) + +// ServeHTTP calls f(w, r). +func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { + f(w, r) +} + +// Helper handlers + +// Error replies to the request with the specified error message and HTTP code. +// The error message should be plain text. +func Error(w ResponseWriter, error string, code int) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(code) + fmt.Fprintln(w, error) +} + +// NotFound replies to the request with an HTTP 404 not found error. +func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) } + +// NotFoundHandler returns a simple request handler +// that replies to each request with a ``404 page not found'' reply. +func NotFoundHandler() Handler { return HandlerFunc(NotFound) } + +// StripPrefix returns a handler that serves HTTP requests +// by removing the given prefix from the request URL's Path +// and invoking the handler h. StripPrefix handles a +// request for a path that doesn't begin with prefix by +// replying with an HTTP 404 not found error. +func StripPrefix(prefix string, h Handler) Handler { + if prefix == "" { + return h + } + return HandlerFunc(func(w ResponseWriter, r *Request) { + if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + r.URL.Path = p + h.ServeHTTP(w, r) + } else { + NotFound(w, r) + } + }) +} + +// Redirect replies to the request with a redirect to url, +// which may be a path relative to the request path. +func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { + if u, err := url.Parse(urlStr); err == nil { + // If url was relative, make absolute by + // combining with request path. + // The browser would probably do this for us, + // but doing it ourselves is more reliable. + + // NOTE(rsc): RFC 2616 says that the Location + // line must be an absolute URI, like + // "http://www.google.com/redirect/", + // not a path like "/redirect/". + // Unfortunately, we don't know what to + // put in the host name section to get the + // client to connect to us again, so we can't + // know the right absolute URI to send back. + // Because of this problem, no one pays attention + // to the RFC; they all send back just a new path. + // So do we. + oldpath := r.URL.Path + if oldpath == "" { // should not happen, but avoid a crash if it does + oldpath = "/" + } + if u.Scheme == "" { + // no leading http://server + if urlStr == "" || urlStr[0] != '/' { + // make relative path absolute + olddir, _ := path.Split(oldpath) + urlStr = olddir + urlStr + } + + var query string + if i := strings.Index(urlStr, "?"); i != -1 { + urlStr, query = urlStr[:i], urlStr[i:] + } + + // clean up but preserve trailing slash + trailing := strings.HasSuffix(urlStr, "/") + urlStr = path.Clean(urlStr) + if trailing && !strings.HasSuffix(urlStr, "/") { + urlStr += "/" + } + urlStr += query + } + } + + w.Header().Set("Location", urlStr) + w.WriteHeader(code) + + // RFC2616 recommends that a short note "SHOULD" be included in the + // response because older user agents may not understand 301/307. + // Shouldn't send the response for POST or HEAD; that leaves GET. + if r.Method == "GET" { + note := "<a href=\"" + htmlEscape(urlStr) + "\">" + statusText[code] + "</a>.\n" + fmt.Fprintln(w, note) + } +} + +var htmlReplacer = strings.NewReplacer( + "&", "&", + "<", "<", + ">", ">", + // """ is shorter than """. + `"`, """, + // "'" is shorter than "'" and apos was not in HTML until HTML5. + "'", "'", +) + +func htmlEscape(s string) string { + return htmlReplacer.Replace(s) +} + +// Redirect to a fixed URL +type redirectHandler struct { + url string + code int +} + +func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) { + Redirect(w, r, rh.url, rh.code) +} + +// RedirectHandler returns a request handler that redirects +// each request it receives to the given url using the given +// status code. +func RedirectHandler(url string, code int) Handler { + return &redirectHandler{url, code} +} + +// ServeMux is an HTTP request multiplexer. +// It matches the URL of each incoming request against a list of registered +// patterns and calls the handler for the pattern that +// most closely matches the URL. +// +// Patterns name fixed, rooted paths, like "/favicon.ico", +// or rooted subtrees, like "/images/" (note the trailing slash). +// Longer patterns take precedence over shorter ones, so that +// if there are handlers registered for both "/images/" +// and "/images/thumbnails/", the latter handler will be +// called for paths beginning "/images/thumbnails/" and the +// former will receive requests for any other paths in the +// "/images/" subtree. +// +// Note that since a pattern ending in a slash names a rooted subtree, +// the pattern "/" matches all paths not matched by other registered +// patterns, not just the URL with Path == "/". +// +// Patterns may optionally begin with a host name, restricting matches to +// URLs on that host only. Host-specific patterns take precedence over +// general patterns, so that a handler might register for the two patterns +// "/codesearch" and "codesearch.google.com/" without also taking over +// requests for "http://www.google.com/". +// +// ServeMux also takes care of sanitizing the URL request path, +// redirecting any request containing . or .. elements to an +// equivalent .- and ..-free URL. +type ServeMux struct { + mu sync.RWMutex + m map[string]muxEntry + hosts bool // whether any patterns contain hostnames +} + +type muxEntry struct { + explicit bool + h Handler + pattern string +} + +// NewServeMux allocates and returns a new ServeMux. +func NewServeMux() *ServeMux { return &ServeMux{m: make(map[string]muxEntry)} } + +// DefaultServeMux is the default ServeMux used by Serve. +var DefaultServeMux = NewServeMux() + +// Does path match pattern? +func pathMatch(pattern, path string) bool { + if len(pattern) == 0 { + // should not happen + return false + } + n := len(pattern) + if pattern[n-1] != '/' { + return pattern == path + } + return len(path) >= n && path[0:n] == pattern +} + +// Return the canonical path for p, eliminating . and .. elements. +func cleanPath(p string) string { + if p == "" { + return "/" + } + if p[0] != '/' { + p = "/" + p + } + np := path.Clean(p) + // path.Clean removes trailing slash except for root; + // put the trailing slash back if necessary. + if p[len(p)-1] == '/' && np != "/" { + np += "/" + } + return np +} + +// Find a handler on a handler map given a path string +// Most-specific (longest) pattern wins +func (mux *ServeMux) match(path string) (h Handler, pattern string) { + var n = 0 + for k, v := range mux.m { + if !pathMatch(k, path) { + continue + } + if h == nil || len(k) > n { + n = len(k) + h = v.h + pattern = v.pattern + } + } + return +} + +// Handler returns the handler to use for the given request, +// consulting r.Method, r.Host, and r.URL.Path. It always returns +// a non-nil handler. If the path is not in its canonical form, the +// handler will be an internally-generated handler that redirects +// to the canonical path. +// +// Handler also returns the registered pattern that matches the +// request or, in the case of internally-generated redirects, +// the pattern that will match after following the redirect. +// +// If there is no registered handler that applies to the request, +// Handler returns a ``page not found'' handler and an empty pattern. +func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + if r.Method != "CONNECT" { + if p := cleanPath(r.URL.Path); p != r.URL.Path { + _, pattern = mux.handler(r.Host, p) + url := *r.URL + url.Path = p + return RedirectHandler(url.String(), StatusMovedPermanently), pattern + } + } + + return mux.handler(r.Host, r.URL.Path) +} + +// handler is the main implementation of Handler. +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { + mux.mu.RLock() + defer mux.mu.RUnlock() + + // Host-specific pattern takes precedence over generic ones + if mux.hosts { + h, pattern = mux.match(host + path) + } + if h == nil { + h, pattern = mux.match(path) + } + if h == nil { + h, pattern = NotFoundHandler(), "" + } + return +} + +// ServeHTTP dispatches the request to the handler whose +// pattern most closely matches the request URL. +func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { + if r.RequestURI == "*" { + if r.ProtoAtLeast(1, 1) { + w.Header().Set("Connection", "close") + } + w.WriteHeader(StatusBadRequest) + return + } + h, _ := mux.Handler(r) + h.ServeHTTP(w, r) +} + +// Handle registers the handler for the given pattern. +// If a handler already exists for pattern, Handle panics. +func (mux *ServeMux) Handle(pattern string, handler Handler) { + mux.mu.Lock() + defer mux.mu.Unlock() + + if pattern == "" { + panic("http: invalid pattern " + pattern) + } + if handler == nil { + panic("http: nil handler") + } + if mux.m[pattern].explicit { + panic("http: multiple registrations for " + pattern) + } + + mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern} + + if pattern[0] != '/' { + mux.hosts = true + } + + // Helpful behavior: + // If pattern is /tree/, insert an implicit permanent redirect for /tree. + // It can be overridden by an explicit registration. + n := len(pattern) + if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit { + // If pattern contains a host name, strip it and use remaining + // path for redirect. + path := pattern + if pattern[0] != '/' { + // In pattern, at least the last character is a '/', so + // strings.Index can't be -1. + path = pattern[strings.Index(pattern, "/"):] + } + mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern} + } +} + +// HandleFunc registers the handler function for the given pattern. +func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + mux.Handle(pattern, HandlerFunc(handler)) +} + +// Handle registers the handler for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } + +// HandleFunc registers the handler function for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + DefaultServeMux.HandleFunc(pattern, handler) +} + +// Serve accepts incoming HTTP connections on the listener l, +// creating a new service goroutine for each. The service goroutines +// read requests and then call handler to reply to them. +// Handler is typically nil, in which case the DefaultServeMux is used. +func Serve(l net.Listener, handler Handler) error { + srv := &Server{Handler: handler} + return srv.Serve(l) +} + +// A Server defines parameters for running an HTTP server. +// The zero value for Server is a valid configuration. +type Server struct { + Addr string // TCP address to listen on, ":http" if empty + Handler Handler // handler to invoke, http.DefaultServeMux if nil + ReadTimeout time.Duration // maximum duration before timing out read of the request + WriteTimeout time.Duration // maximum duration before timing out write of the response + MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // TLSNextProto optionally specifies a function to take over + // ownership of the provided TLS connection when an NPN + // protocol upgrade has occurred. The map key is the protocol + // name negotiated. The Handler argument should be used to + // handle HTTP requests and will initialize the Request's TLS + // and RemoteAddr if not already set. The connection is + // automatically closed when the function returns. + TLSNextProto map[string]func(*Server, *tls.Conn, Handler) + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, ConnState) + + // ErrorLog specifies an optional logger for errors accepting + // connections and unexpected behavior from handlers. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger + + disableKeepAlives int32 // accessed atomically. +} + +// A ConnState represents the state of a client connection to a server. +// It's used by the optional Server.ConnState hook. +type ConnState int + +const ( + // StateNew represents a new connection that is expected to + // send a request immediately. Connections begin at this + // state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateActive represents a connection that has read 1 or more + // bytes of a request. The Server.ConnState hook for + // StateActive fires before the request has entered a handler + // and doesn't fire again until the request has been + // handled. After the request is handled, the state + // transitions to StateClosed, StateHijacked, or StateIdle. + StateActive + + // StateIdle represents a connection that has finished + // handling a request and is in the keep-alive state, waiting + // for a new request. Connections transition from StateIdle + // to either StateActive or StateClosed. + StateIdle + + // StateHijacked represents a hijacked connection. + // This is a terminal state. It does not transition to StateClosed. + StateHijacked + + // StateClosed represents a closed connection. + // This is a terminal state. Hijacked connections do not + // transition to StateClosed. + StateClosed +) + +var stateName = map[ConnState]string{ + StateNew: "new", + StateActive: "active", + StateIdle: "idle", + StateHijacked: "hijacked", + StateClosed: "closed", +} + +func (c ConnState) String() string { + return stateName[c] +} + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultServeMux + } + if req.RequestURI == "*" && req.Method == "OPTIONS" { + handler = globalOptionsHandler{} + } + handler.ServeHTTP(rw, req) +} + +// ListenAndServe listens on the TCP network address srv.Addr and then +// calls Serve to handle requests on incoming connections. If +// srv.Addr is blank, ":http" is used. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) +} + +// Serve accepts incoming connections on the Listener l, creating a +// new service goroutine for each. The service goroutines read requests and +// then call srv.Handler to reply to them. +func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + var tempDelay time.Duration // how long to sleep on accept failure + for { + rw, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay) + time.Sleep(tempDelay) + continue + } + return e + } + tempDelay = 0 + c, err := srv.newConn(rw) + if err != nil { + continue + } + c.setState(c.rwc, StateNew) // before Serve can return + go c.serve() + } +} + +func (s *Server) doKeepAlives() bool { + return atomic.LoadInt32(&s.disableKeepAlives) == 0 +} + +// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. +// By default, keep-alives are always enabled. Only very +// resource-constrained environments or servers in the process of +// shutting down should disable them. +func (s *Server) SetKeepAlivesEnabled(v bool) { + if v { + atomic.StoreInt32(&s.disableKeepAlives, 0) + } else { + atomic.StoreInt32(&s.disableKeepAlives, 1) + } +} + +func (s *Server) logf(format string, args ...interface{}) { + if s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// ListenAndServe listens on the TCP network address addr +// and then calls Serve with handler to handle requests +// on incoming connections. Handler is typically nil, +// in which case the DefaultServeMux is used. +// +// A trivial example server is: +// +// package main +// +// import ( +// "io" +// "net/http" +// "log" +// ) +// +// // hello world, the web server +// func HelloServer(w http.ResponseWriter, req *http.Request) { +// io.WriteString(w, "hello, world!\n") +// } +// +// func main() { +// http.HandleFunc("/hello", HelloServer) +// err := http.ListenAndServe(":12345", nil) +// if err != nil { +// log.Fatal("ListenAndServe: ", err) +// } +// } +func ListenAndServe(addr string, handler Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServe() +} + +// ListenAndServeTLS acts identically to ListenAndServe, except that it +// expects HTTPS connections. Additionally, files containing a certificate and +// matching private key for the server must be provided. If the certificate +// is signed by a certificate authority, the certFile should be the concatenation +// of the server's certificate followed by the CA's certificate. +// +// A trivial example server is: +// +// import ( +// "log" +// "net/http" +// ) +// +// func handler(w http.ResponseWriter, req *http.Request) { +// w.Header().Set("Content-Type", "text/plain") +// w.Write([]byte("This is an example server.\n")) +// } +// +// func main() { +// http.HandleFunc("/", handler) +// log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/") +// err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil) +// if err != nil { +// log.Fatal(err) +// } +// } +// +// One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem. +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and +// then calls Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for +// the server must be provided. If the certificate is signed by a +// certificate authority, the certFile should be the concatenation +// of the server's certificate followed by the CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + config := &tls.Config{} + if srv.TLSConfig != nil { + *config = *srv.TLSConfig + } + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) + return srv.Serve(tlsListener) +} + +// TimeoutHandler returns a Handler that runs h with the given time limit. +// +// The new Handler calls h.ServeHTTP to handle each request, but if a +// call runs for longer than its time limit, the handler responds with +// a 503 Service Unavailable error and the given message in its body. +// (If msg is empty, a suitable default message will be sent.) +// After such a timeout, writes by h to its ResponseWriter will return +// ErrHandlerTimeout. +func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { + f := func() <-chan time.Time { + return time.After(dt) + } + return &timeoutHandler{h, f, msg} +} + +// ErrHandlerTimeout is returned on ResponseWriter Write calls +// in handlers which have timed out. +var ErrHandlerTimeout = errors.New("http: Handler timeout") + +type timeoutHandler struct { + handler Handler + timeout func() <-chan time.Time // returns channel producing a timeout + body string +} + +func (h *timeoutHandler) errorBody() string { + if h.body != "" { + return h.body + } + return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>" +} + +func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + done := make(chan bool, 1) + tw := &timeoutWriter{w: w} + go func() { + h.handler.ServeHTTP(tw, r) + done <- true + }() + select { + case <-done: + return + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader { + tw.w.WriteHeader(StatusServiceUnavailable) + tw.w.Write([]byte(h.errorBody())) + } + tw.timedOut = true + } +} + +type timeoutWriter struct { + w ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool +} + +func (tw *timeoutWriter) Header() Header { + return tw.w.Header() +} + +func (tw *timeoutWriter) Write(p []byte) (int, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + tw.wroteHeader = true // implicitly at least + if tw.timedOut { + return 0, ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut || tw.wroteHeader { + return + } + tw.wroteHeader = true + tw.w.WriteHeader(code) +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// globalOptionsHandler responds to "OPTIONS *" requests. +type globalOptionsHandler struct{} + +func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "0") + if r.ContentLength != 0 { + // Read up to 4KB of OPTIONS body (as mentioned in the + // spec as being reserved for future use), but anything + // over that is considered a waste of server resources + // (or an attack) and we abort and close the connection, + // courtesy of MaxBytesReader's EOF behavior. + mb := MaxBytesReader(w, r.Body, 4<<10) + io.Copy(ioutil.Discard, mb) + } +} + +type eofReaderWithWriteTo struct{} + +func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil } +func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF } + +// eofReader is a non-nil io.ReadCloser that always returns EOF. +// It has a WriteTo method so io.Copy won't need a buffer. +var eofReader = &struct { + eofReaderWithWriteTo + io.Closer +}{ + eofReaderWithWriteTo{}, + ioutil.NopCloser(nil), +} + +// Verify that an io.Copy from an eofReader won't require a buffer. +var _ io.WriterTo = eofReader + +// initNPNRequest is an HTTP handler that initializes certain +// uninitialized fields in its *Request. Such partially-initialized +// Requests come from NPN protocol handlers. +type initNPNRequest struct { + c *tls.Conn + h serverHandler +} + +func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { + if req.TLS == nil { + req.TLS = &tls.ConnectionState{} + *req.TLS = h.c.ConnectionState() + } + if req.Body == nil { + req.Body = eofReader + } + if req.RemoteAddr == "" { + req.RemoteAddr = h.c.RemoteAddr().String() + } + h.h.ServeHTTP(rw, req) +} + +// loggingConn is used for debugging. +type loggingConn struct { + name string + net.Conn +} + +var ( + uniqNameMu sync.Mutex + uniqNameNext = make(map[string]int) +) + +func newLoggingConn(baseName string, c net.Conn) net.Conn { + uniqNameMu.Lock() + defer uniqNameMu.Unlock() + uniqNameNext[baseName]++ + return &loggingConn{ + name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), + Conn: c, + } +} + +func (c *loggingConn) Write(p []byte) (n int, err error) { + log.Printf("%s.Write(%d) = ....", c.name, len(p)) + n, err = c.Conn.Write(p) + log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Read(p []byte) (n int, err error) { + log.Printf("%s.Read(%d) = ....", c.name, len(p)) + n, err = c.Conn.Read(p) + log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Close() (err error) { + log.Printf("%s.Close() = ...", c.name) + err = c.Conn.Close() + log.Printf("%s.Close() = %v", c.name, err) + return +} + +// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr. +// It only contains one field (and a pointer field at that), so it +// fits in an interface value without an extra allocation. +type checkConnErrorWriter struct { + c *conn +} + +func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { + n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil. + if err != nil && w.c.werr == nil { + w.c.werr = err + } + return +} diff --git a/src/net/http/sniff.go b/src/net/http/sniff.go new file mode 100644 index 000000000..68f519b05 --- /dev/null +++ b/src/net/http/sniff.go @@ -0,0 +1,214 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "encoding/binary" +) + +// The algorithm uses at most sniffLen bytes to make its decision. +const sniffLen = 512 + +// DetectContentType implements the algorithm described +// at http://mimesniff.spec.whatwg.org/ to determine the +// Content-Type of the given data. It considers at most the +// first 512 bytes of data. DetectContentType always returns +// a valid MIME type: if it cannot determine a more specific one, it +// returns "application/octet-stream". +func DetectContentType(data []byte) string { + if len(data) > sniffLen { + data = data[:sniffLen] + } + + // Index of the first non-whitespace byte in data. + firstNonWS := 0 + for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ { + } + + for _, sig := range sniffSignatures { + if ct := sig.match(data, firstNonWS); ct != "" { + return ct + } + } + + return "application/octet-stream" // fallback +} + +func isWS(b byte) bool { + return bytes.IndexByte([]byte("\t\n\x0C\r "), b) != -1 +} + +type sniffSig interface { + // match returns the MIME type of the data, or "" if unknown. + match(data []byte, firstNonWS int) string +} + +// Data matching the table in section 6. +var sniffSignatures = []sniffSig{ + htmlSig("<!DOCTYPE HTML"), + htmlSig("<HTML"), + htmlSig("<HEAD"), + htmlSig("<SCRIPT"), + htmlSig("<IFRAME"), + htmlSig("<H1"), + htmlSig("<DIV"), + htmlSig("<FONT"), + htmlSig("<TABLE"), + htmlSig("<A"), + htmlSig("<STYLE"), + htmlSig("<TITLE"), + htmlSig("<B"), + htmlSig("<BODY"), + htmlSig("<BR"), + htmlSig("<P"), + htmlSig("<!--"), + + &maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"}, + + &exactSig{[]byte("%PDF-"), "application/pdf"}, + &exactSig{[]byte("%!PS-Adobe-"), "application/postscript"}, + + // UTF BOMs. + &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"}, + &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"}, + &maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"}, + + &exactSig{[]byte("GIF87a"), "image/gif"}, + &exactSig{[]byte("GIF89a"), "image/gif"}, + &exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"}, + &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"}, + &exactSig{[]byte("BM"), "image/bmp"}, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"), + pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"), + ct: "image/webp", + }, + &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"}, + &exactSig{[]byte("\x4F\x67\x67\x53\x00"), "application/ogg"}, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), + pat: []byte("RIFF\x00\x00\x00\x00WAVE"), + ct: "audio/wave", + }, + &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"}, + &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"}, + &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"}, + &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"}, + + // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4. + //mp4Sig(0), + + textSig(0), // should be last +} + +type exactSig struct { + sig []byte + ct string +} + +func (e *exactSig) match(data []byte, firstNonWS int) string { + if bytes.HasPrefix(data, e.sig) { + return e.ct + } + return "" +} + +type maskedSig struct { + mask, pat []byte + skipWS bool + ct string +} + +func (m *maskedSig) match(data []byte, firstNonWS int) string { + if m.skipWS { + data = data[firstNonWS:] + } + if len(data) < len(m.mask) { + return "" + } + for i, mask := range m.mask { + db := data[i] & mask + if db != m.pat[i] { + return "" + } + } + return m.ct +} + +type htmlSig []byte + +func (h htmlSig) match(data []byte, firstNonWS int) string { + data = data[firstNonWS:] + if len(data) < len(h)+1 { + return "" + } + for i, b := range h { + db := data[i] + if 'A' <= b && b <= 'Z' { + db &= 0xDF + } + if b != db { + return "" + } + } + // Next byte must be space or right angle bracket. + if db := data[len(h)]; db != ' ' && db != '>' { + return "" + } + return "text/html; charset=utf-8" +} + +type mp4Sig int + +func (mp4Sig) match(data []byte, firstNonWS int) string { + // c.f. section 6.1. + if len(data) < 8 { + return "" + } + boxSize := int(binary.BigEndian.Uint32(data[:4])) + if boxSize%4 != 0 || len(data) < boxSize { + return "" + } + if !bytes.Equal(data[4:8], []byte("ftyp")) { + return "" + } + for st := 8; st < boxSize; st += 4 { + if st == 12 { + // minor version number + continue + } + seg := string(data[st : st+3]) + switch seg { + case "mp4", "iso", "M4V", "M4P", "M4B": + return "video/mp4" + /* The remainder are not in the spec. + case "M4A": + return "audio/mp4" + case "3gp": + return "video/3gpp" + case "jp2": + return "image/jp2" // JPEG 2000 + */ + } + } + return "" +} + +type textSig int + +func (textSig) match(data []byte, firstNonWS int) string { + // c.f. section 5, step 4. + for _, b := range data[firstNonWS:] { + switch { + case 0x00 <= b && b <= 0x08, + b == 0x0B, + 0x0E <= b && b <= 0x1A, + 0x1C <= b && b <= 0x1F: + return "" + } + } + return "text/plain; charset=utf-8" +} diff --git a/src/net/http/sniff_test.go b/src/net/http/sniff_test.go new file mode 100644 index 000000000..24ca27afc --- /dev/null +++ b/src/net/http/sniff_test.go @@ -0,0 +1,171 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + . "net/http" + "net/http/httptest" + "reflect" + "strconv" + "strings" + "testing" +) + +var sniffTests = []struct { + desc string + data []byte + contentType string +}{ + // Some nonsense. + {"Empty", []byte{}, "text/plain; charset=utf-8"}, + {"Binary", []byte{1, 2, 3}, "application/octet-stream"}, + + {"HTML document #1", []byte(`<HtMl><bOdY>blah blah blah</body></html>`), "text/html; charset=utf-8"}, + {"HTML document #2", []byte(`<HTML></HTML>`), "text/html; charset=utf-8"}, + {"HTML document #3 (leading whitespace)", []byte(` <!DOCTYPE HTML>...`), "text/html; charset=utf-8"}, + {"HTML document #4 (leading CRLF)", []byte("\r\n<html>..."), "text/html; charset=utf-8"}, + + {"Plain text", []byte(`This is not HTML. It has โ though.`), "text/plain; charset=utf-8"}, + + {"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"}, + + // Image types. + {"GIF 87a", []byte(`GIF87a`), "image/gif"}, + {"GIF 89a", []byte(`GIF89a...`), "image/gif"}, + + // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4. + //{"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, + //{"MP4 audio", []byte("\x00\x00\x00\x20ftypM4A \x00\x00\x00\x00M4A mp42isom\x00\x00\x00\x00"), "audio/mp4"}, +} + +func TestDetectContentType(t *testing.T) { + for _, tt := range sniffTests { + ct := DetectContentType(tt.data) + if ct != tt.contentType { + t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType) + } + } +} + +func TestServerContentType(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + i, _ := strconv.Atoi(r.FormValue("i")) + tt := sniffTests[i] + n, err := w.Write(tt.data) + if n != len(tt.data) || err != nil { + log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data)) + } + })) + defer ts.Close() + + for i, tt := range sniffTests { + resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i)) + if err != nil { + t.Errorf("%v: %v", tt.desc, err) + continue + } + if ct := resp.Header.Get("Content-Type"); ct != tt.contentType { + t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("%v: reading body: %v", tt.desc, err) + } else if !bytes.Equal(data, tt.data) { + t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data) + } + resp.Body.Close() + } +} + +// Issue 5953: shouldn't sniff if the handler set a Content-Type header, +// even if it's the empty string. +func TestServerIssue5953(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header()["Content-Type"] = []string{""} + fmt.Fprintf(w, "<html><head></head><body>hi</body></html>") + })) + defer ts.Close() + + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + got := resp.Header["Content-Type"] + want := []string{""} + if !reflect.DeepEqual(got, want) { + t.Errorf("Content-Type = %q; want %q", got, want) + } + resp.Body.Close() +} + +func TestContentTypeWithCopy(t *testing.T) { + defer afterTest(t) + + const ( + input = "\n<html>\n\t<head>\n" + expected = "text/html; charset=utf-8" + ) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Use io.Copy from a bytes.Buffer to trigger ReadFrom. + buf := bytes.NewBuffer([]byte(input)) + n, err := io.Copy(w, buf) + if int(n) != len(input) || err != nil { + t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input)) + } + })) + defer ts.Close() + + resp, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if ct := resp.Header.Get("Content-Type"); ct != expected { + t.Errorf("Content-Type = %q, want %q", ct, expected) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("reading body: %v", err) + } else if !bytes.Equal(data, []byte(input)) { + t.Errorf("data is %q, want %q", data, input) + } + resp.Body.Close() +} + +func TestSniffWriteSize(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + size, _ := strconv.Atoi(r.FormValue("size")) + written, err := io.WriteString(w, strings.Repeat("a", size)) + if err != nil { + t.Errorf("write of %d bytes: %v", size, err) + return + } + if written != size { + t.Errorf("write of %d bytes wrote %d bytes", size, written) + } + })) + defer ts.Close() + for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { + res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size)) + if err != nil { + t.Fatalf("size %d: %v", size, err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatalf("size %d: io.Copy of body = %v", size, err) + } + if err := res.Body.Close(); err != nil { + t.Fatalf("size %d: body Close = %v", size, err) + } + } +} diff --git a/src/net/http/status.go b/src/net/http/status.go new file mode 100644 index 000000000..d253bd5cb --- /dev/null +++ b/src/net/http/status.go @@ -0,0 +1,120 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// HTTP status codes, defined in RFC 2616. +const ( + StatusContinue = 100 + StatusSwitchingProtocols = 101 + + StatusOK = 200 + StatusCreated = 201 + StatusAccepted = 202 + StatusNonAuthoritativeInfo = 203 + StatusNoContent = 204 + StatusResetContent = 205 + StatusPartialContent = 206 + + StatusMultipleChoices = 300 + StatusMovedPermanently = 301 + StatusFound = 302 + StatusSeeOther = 303 + StatusNotModified = 304 + StatusUseProxy = 305 + StatusTemporaryRedirect = 307 + + StatusBadRequest = 400 + StatusUnauthorized = 401 + StatusPaymentRequired = 402 + StatusForbidden = 403 + StatusNotFound = 404 + StatusMethodNotAllowed = 405 + StatusNotAcceptable = 406 + StatusProxyAuthRequired = 407 + StatusRequestTimeout = 408 + StatusConflict = 409 + StatusGone = 410 + StatusLengthRequired = 411 + StatusPreconditionFailed = 412 + StatusRequestEntityTooLarge = 413 + StatusRequestURITooLong = 414 + StatusUnsupportedMediaType = 415 + StatusRequestedRangeNotSatisfiable = 416 + StatusExpectationFailed = 417 + StatusTeapot = 418 + + StatusInternalServerError = 500 + StatusNotImplemented = 501 + StatusBadGateway = 502 + StatusServiceUnavailable = 503 + StatusGatewayTimeout = 504 + StatusHTTPVersionNotSupported = 505 + + // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1. + // See discussion at https://codereview.appspot.com/7678043/ + statusPreconditionRequired = 428 + statusTooManyRequests = 429 + statusRequestHeaderFieldsTooLarge = 431 + statusNetworkAuthenticationRequired = 511 +) + +var statusText = map[int]string{ + StatusContinue: "Continue", + StatusSwitchingProtocols: "Switching Protocols", + + StatusOK: "OK", + StatusCreated: "Created", + StatusAccepted: "Accepted", + StatusNonAuthoritativeInfo: "Non-Authoritative Information", + StatusNoContent: "No Content", + StatusResetContent: "Reset Content", + StatusPartialContent: "Partial Content", + + StatusMultipleChoices: "Multiple Choices", + StatusMovedPermanently: "Moved Permanently", + StatusFound: "Found", + StatusSeeOther: "See Other", + StatusNotModified: "Not Modified", + StatusUseProxy: "Use Proxy", + StatusTemporaryRedirect: "Temporary Redirect", + + StatusBadRequest: "Bad Request", + StatusUnauthorized: "Unauthorized", + StatusPaymentRequired: "Payment Required", + StatusForbidden: "Forbidden", + StatusNotFound: "Not Found", + StatusMethodNotAllowed: "Method Not Allowed", + StatusNotAcceptable: "Not Acceptable", + StatusProxyAuthRequired: "Proxy Authentication Required", + StatusRequestTimeout: "Request Timeout", + StatusConflict: "Conflict", + StatusGone: "Gone", + StatusLengthRequired: "Length Required", + StatusPreconditionFailed: "Precondition Failed", + StatusRequestEntityTooLarge: "Request Entity Too Large", + StatusRequestURITooLong: "Request URI Too Long", + StatusUnsupportedMediaType: "Unsupported Media Type", + StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", + StatusExpectationFailed: "Expectation Failed", + StatusTeapot: "I'm a teapot", + + StatusInternalServerError: "Internal Server Error", + StatusNotImplemented: "Not Implemented", + StatusBadGateway: "Bad Gateway", + StatusServiceUnavailable: "Service Unavailable", + StatusGatewayTimeout: "Gateway Timeout", + StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + + statusPreconditionRequired: "Precondition Required", + statusTooManyRequests: "Too Many Requests", + statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", + statusNetworkAuthenticationRequired: "Network Authentication Required", +} + +// StatusText returns a text for the HTTP status code. It returns the empty +// string if the code is unknown. +func StatusText(code int) string { + return statusText[code] +} diff --git a/src/net/http/testdata/file b/src/net/http/testdata/file new file mode 100644 index 000000000..11f11f9be --- /dev/null +++ b/src/net/http/testdata/file @@ -0,0 +1 @@ +0123456789 diff --git a/src/net/http/testdata/index.html b/src/net/http/testdata/index.html new file mode 100644 index 000000000..da8e1e93d --- /dev/null +++ b/src/net/http/testdata/index.html @@ -0,0 +1 @@ +index.html says hello diff --git a/src/net/http/testdata/style.css b/src/net/http/testdata/style.css new file mode 100644 index 000000000..208d16d42 --- /dev/null +++ b/src/net/http/testdata/style.css @@ -0,0 +1 @@ +body {} diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go new file mode 100644 index 000000000..520500330 --- /dev/null +++ b/src/net/http/transfer.go @@ -0,0 +1,737 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http/internal" + "net/textproto" + "sort" + "strconv" + "strings" + "sync" +) + +// ErrLineTooLong is returned when reading request or response bodies +// with malformed chunked encoding. +var ErrLineTooLong = internal.ErrLineTooLong + +type errorReader struct { + err error +} + +func (r *errorReader) Read(p []byte) (n int, err error) { + return 0, r.err +} + +// transferWriter inspects the fields of a user-supplied Request or Response, +// sanitizes them without changing the user object and provides methods for +// writing the respective header, body and trailer in wire format. +type transferWriter struct { + Method string + Body io.Reader + BodyCloser io.Closer + ResponseToHEAD bool + ContentLength int64 // -1 means unknown, 0 means exactly none + Close bool + TransferEncoding []string + Trailer Header +} + +func newTransferWriter(r interface{}) (t *transferWriter, err error) { + t = &transferWriter{} + + // Extract relevant fields + atLeastHTTP11 := false + switch rr := r.(type) { + case *Request: + if rr.ContentLength != 0 && rr.Body == nil { + return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) + } + t.Method = rr.Method + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.ContentLength + t.Close = rr.Close + t.TransferEncoding = rr.TransferEncoding + t.Trailer = rr.Trailer + atLeastHTTP11 = rr.ProtoAtLeast(1, 1) + if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 { + if t.ContentLength == 0 { + // Test to see if it's actually zero or just unset. + var buf [1]byte + n, rerr := io.ReadFull(t.Body, buf[:]) + if rerr != nil && rerr != io.EOF { + t.ContentLength = -1 + t.Body = &errorReader{rerr} + } else if n == 1 { + // Oh, guess there is data in this Body Reader after all. + // The ContentLength field just wasn't set. + // Stich the Body back together again, re-attaching our + // consumed byte. + t.ContentLength = -1 + t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body) + } else { + // Body is actually empty. + t.Body = nil + t.BodyCloser = nil + } + } + if t.ContentLength < 0 { + t.TransferEncoding = []string{"chunked"} + } + } + case *Response: + if rr.Request != nil { + t.Method = rr.Request.Method + } + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.ContentLength + t.Close = rr.Close + t.TransferEncoding = rr.TransferEncoding + t.Trailer = rr.Trailer + atLeastHTTP11 = rr.ProtoAtLeast(1, 1) + t.ResponseToHEAD = noBodyExpected(t.Method) + } + + // Sanitize Body,ContentLength,TransferEncoding + if t.ResponseToHEAD { + t.Body = nil + if chunked(t.TransferEncoding) { + t.ContentLength = -1 + } + } else { + if !atLeastHTTP11 || t.Body == nil { + t.TransferEncoding = nil + } + if chunked(t.TransferEncoding) { + t.ContentLength = -1 + } else if t.Body == nil { // no chunking, no body + t.ContentLength = 0 + } + } + + // Sanitize Trailer + if !chunked(t.TransferEncoding) { + t.Trailer = nil + } + + return t, nil +} + +func noBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} + +func (t *transferWriter) shouldSendContentLength() bool { + if chunked(t.TransferEncoding) { + return false + } + if t.ContentLength > 0 { + return true + } + // Many servers expect a Content-Length for these methods + if t.Method == "POST" || t.Method == "PUT" { + return true + } + if t.ContentLength == 0 && isIdentity(t.TransferEncoding) { + return true + } + + return false +} + +func (t *transferWriter) WriteHeader(w io.Writer) error { + if t.Close { + if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + return err + } + } + + // Write Content-Length and/or Transfer-Encoding whose values are a + // function of the sanitized field triple (Body, ContentLength, + // TransferEncoding) + if t.shouldSendContentLength() { + if _, err := io.WriteString(w, "Content-Length: "); err != nil { + return err + } + if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + return err + } + } else if chunked(t.TransferEncoding) { + if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + return err + } + } + + // Write Trailer header + if t.Trailer != nil { + keys := make([]string, 0, len(t.Trailer)) + for k := range t.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return &badStringError{"invalid Trailer key", k} + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + // TODO: could do better allocation-wise here, but trailers are rare, + // so being lazy for now. + if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + return err + } + } + } + + return nil +} + +func (t *transferWriter) WriteBody(w io.Writer) error { + var err error + var ncopy int64 + + // Write body + if t.Body != nil { + if chunked(t.TransferEncoding) { + cw := internal.NewChunkedWriter(w) + _, err = io.Copy(cw, t.Body) + if err == nil { + err = cw.Close() + } + } else if t.ContentLength == -1 { + ncopy, err = io.Copy(w, t.Body) + } else { + ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) + if err != nil { + return err + } + var nextra int64 + nextra, err = io.Copy(ioutil.Discard, t.Body) + ncopy += nextra + } + if err != nil { + return err + } + if err = t.BodyCloser.Close(); err != nil { + return err + } + } + + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { + return fmt.Errorf("http: ContentLength=%d with Body length %d", + t.ContentLength, ncopy) + } + + // TODO(petar): Place trailer writer code here. + if chunked(t.TransferEncoding) { + // Write Trailer header + if t.Trailer != nil { + if err := t.Trailer.Write(w); err != nil { + return err + } + } + // Last chunk, empty trailer + _, err = io.WriteString(w, "\r\n") + } + return err +} + +type transferReader struct { + // Input + Header Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + TransferEncoding []string + Close bool + Trailer Header +} + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC2616, section 4.4. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +var ( + suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} + suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} +) + +func suppressedHeaders(status int) []string { + switch { + case status == 304: + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + return suppressedHeaders304 + case !bodyAllowedForStatus(status): + return suppressedHeadersNoBody + } + return nil +} + +// msg is *Request or *Response. +func readTransfer(msg interface{}, r *bufio.Reader) (err error) { + t := &transferReader{RequestMethod: "GET"} + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true) + isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } + case *Request: + t.Header = rr.Header + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + default: + panic("unexpected type") + } + + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + + // Transfer encoding, content length + t.TransferEncoding, err = fixTransferEncoding(t.RequestMethod, t.Header) + if err != nil { + return err + } + + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) + if err != nil { + return err + } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding) + if err != nil { + return err + } + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC2616, section 4.4. + switch msg.(type) { + case *Response: + if realLength == -1 && + !chunked(t.TransferEncoding) && + bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + switch { + case chunked(t.TransferEncoding): + if noBodyExpected(t.RequestMethod) { + t.Body = eofReader + } else { + t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = eofReader + case realLength > 0: + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + default: + // realLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{src: r, closing: t.Close} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = eofReader + } + } + + // Unify output + switch rr := msg.(type) { + case *Request: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + rr.TransferEncoding = t.TransferEncoding + rr.Close = t.Close + rr.Trailer = t.Trailer + case *Response: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + rr.TransferEncoding = t.TransferEncoding + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } + +// Checks whether the encoding is explicitly "identity". +func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } + +// Sanitize transfer encoding +func fixTransferEncoding(requestMethod string, header Header) ([]string, error) { + raw, present := header["Transfer-Encoding"] + if !present { + return nil, nil + } + + delete(header, "Transfer-Encoding") + + encodings := strings.Split(raw[0], ",") + te := make([]string, 0, len(encodings)) + // TODO: Even though we only support "identity" and "chunked" + // encodings, the loop below is designed with foresight. One + // invariant that must be maintained is that, if present, + // chunked encoding must always come first. + for _, encoding := range encodings { + encoding = strings.ToLower(strings.TrimSpace(encoding)) + // "identity" encoding is not recorded + if encoding == "identity" { + break + } + if encoding != "chunked" { + return nil, &badStringError{"unsupported transfer encoding", encoding} + } + te = te[0 : len(te)+1] + te[len(te)-1] = encoding + } + if len(te) > 1 { + return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")} + } + if len(te) > 0 { + // Chunked encoding trumps Content-Length. See RFC 2616 + // Section 4.4. Currently len(te) > 0 implies chunked + // encoding. + delete(header, "Content-Length") + return te, nil + } + + return nil, nil +} + +// Determine the expected body length, using RFC 2616 Section 4.4. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) { + + // Logic based on response type or status + if noBodyExpected(requestMethod) { + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked(te) { + return -1, nil + } + + // Logic based on Content-Length + cl := strings.TrimSpace(header.get("Content-Length")) + if cl != "" { + n, err := parseContentLength(cl) + if err != nil { + return -1, err + } + return n, nil + } else { + header.Del("Content-Length") + } + + if !isResponse && requestMethod == "GET" { + // RFC 2616 doesn't explicitly permit nor forbid an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers +func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { + if major < 1 { + return true + } else if major == 1 && minor == 0 { + if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") { + return true + } + return false + } else { + // TODO: Should split on commas, toss surrounding white space, + // and check each field. + if strings.ToLower(header.get("Connection")) == "close" { + if removeCloseHeader { + header.Del("Connection") + } + return true + } + } + return false +} + +// Parse the trailer header +func fixTrailer(header Header, te []string) (Header, error) { + raw := header.get("Trailer") + if raw == "" { + return nil, nil + } + + header.Del("Trailer") + trailer := make(Header) + keys := strings.Split(raw, ",") + for _, key := range keys { + key = CanonicalHeaderKey(strings.TrimSpace(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + return nil, &badStringError{"bad trailer key", key} + } + trailer[key] = nil + } + if len(trailer) == 0 { + return nil, nil + } + if !chunked(te) { + // Trailer and no chunking + return nil, ErrUnexpectedTrailer + } + return trailer, nil +} + +// body turns a Reader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + src io.Reader + hdr interface{} // non-nil (Response or Request) value means read trailer + r *bufio.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + + mu sync.Mutex // guards closed, and calls to Read and Close + closed bool +} + +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") + +func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + n, err = b.src.Read(p) + + if err == io.EOF { + // Chunked case. Read the trailer. + if b.hdr != nil { + if e := b.readTrailer(); e != nil { + err = e + } + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } + } + } + + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + } + } + + return n, err +} + +var ( + singleCRLF = []byte("\r\n") + doubleCRLF = []byte("\r\n\r\n") +) + +func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { + for peekSize := 4; ; peekSize++ { + // This loop stops when Peek returns an error, + // which it does when r's buffer has been filled. + buf, err := r.Peek(peekSize) + if bytes.HasSuffix(buf, doubleCRLF) { + return true + } + if err != nil { + break + } + } + return false +} + +var errTrailerEOF = errors.New("http: unexpected EOF reading trailer") + +func (b *body) readTrailer() error { + // The common case, since nobody uses trailers. + buf, err := b.r.Peek(2) + if bytes.Equal(buf, singleCRLF) { + b.r.ReadByte() + b.r.ReadByte() + return nil + } + if len(buf) < 2 { + return errTrailerEOF + } + if err != nil { + return err + } + + // Make sure there's a header terminator coming up, to prevent + // a DoS with an unbounded size Trailer. It's not easy to + // slip in a LimitReader here, as textproto.NewReader requires + // a concrete *bufio.Reader. Also, we can't get all the way + // back up to our conn's LimitedReader that *might* be backing + // this bufio.Reader. Instead, a hack: we iteratively Peek up + // to the bufio.Reader's max size, looking for a double CRLF. + // This limits the trailer to the underlying buffer size, typically 4kB. + if !seeUpcomingDoubleCRLF(b.r) { + return errors.New("http: suspiciously long trailer after chunked body") + } + + hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() + if err != nil { + if err == io.EOF { + return errTrailerEOF + } + return err + } + switch rr := b.hdr.(type) { + case *Request: + mergeSetHeader(&rr.Trailer, Header(hdr)) + case *Response: + mergeSetHeader(&rr.Trailer, Header(hdr)) + } + return nil +} + +func mergeSetHeader(dst *Header, src Header) { + if *dst == nil { + *dst = src + return + } + for k, vv := range src { + (*dst)[k] = vv + } +} + +func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil + } + var err error + switch { + case b.hdr == nil && b.closing: + // no trailer and closing the connection next. + // no point in reading to EOF. + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(ioutil.Discard, bodyLocked{b}) + } + b.closed = true + return err +} + +// bodyLocked is a io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = strings.TrimSpace(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseInt(cl, 10, 64) + if err != nil || n < 0 { + return 0, &badStringError{"bad Content-Length", cl} + } + return n, nil + +} diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go new file mode 100644 index 000000000..48cd540b9 --- /dev/null +++ b/src/net/http/transfer_test.go @@ -0,0 +1,64 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "io" + "strings" + "testing" +) + +func TestBodyReadBadTrailer(t *testing.T) { + b := &body{ + src: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), + } + buf := make([]byte, 7) + n, err := b.Read(buf[:3]) + got := string(buf[:n]) + if got != "foo" || err != nil { + t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if got != "bar" || err != nil { + t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if err == nil { + t.Errorf("final Read was successful (%q), expected error from trailer read", got) + } +} + +func TestFinalChunkedBodyReadEOF(t *testing.T) { + res, err := ReadResponse(bufio.NewReader(strings.NewReader( + "HTTP/1.1 200 OK\r\n"+ + "Transfer-Encoding: chunked\r\n"+ + "\r\n"+ + "0a\r\n"+ + "Body here\n\r\n"+ + "09\r\n"+ + "continued\r\n"+ + "0\r\n"+ + "\r\n")), nil) + if err != nil { + t.Fatal(err) + } + want := "Body here\ncontinued" + buf := make([]byte, len(want)) + n, err := res.Body.Read(buf) + if n != len(want) || err != io.EOF { + t.Logf("body = %#v", res.Body) + t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want)) + } + if string(buf) != want { + t.Errorf("buf = %q; want %q", buf, want) + } +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go new file mode 100644 index 000000000..782f7cd39 --- /dev/null +++ b/src/net/http/transport.go @@ -0,0 +1,1275 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client implementation. See RFC 2616. +// +// This is the low-level Transport implementation of RoundTripper. +// The high-level interface is in client.go. + +package http + +import ( + "bufio" + "compress/gzip" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/url" + "os" + "strings" + "sync" + "time" +) + +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and +// $no_proxy) environment variables. +var DefaultTransport RoundTripper = &Transport{ + Proxy: ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, +} + +// DefaultMaxIdleConnsPerHost is the default value of Transport's +// MaxIdleConnsPerHost. +const DefaultMaxIdleConnsPerHost = 2 + +// Transport is an implementation of RoundTripper that supports HTTP, +// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). +// Transport can also cache connections for future re-use. +type Transport struct { + idleMu sync.Mutex + wantIdle bool // user has requested to close all idle conns + idleConn map[connectMethodKey][]*persistConn + idleConnCh map[connectMethodKey]chan *persistConn + + reqMu sync.Mutex + reqCanceler map[*Request]func() + + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*Request) (*url.URL, error) + + // Dial specifies the dial function for creating unencrypted + // TCP connections. + // If Dial is nil, net.Dial is used. + Dial func(network, addr string) (net.Conn, error) + + // DialTLS specifies an optional dial function for creating + // TLS connections for non-proxied HTTPS requests. + // + // If DialTLS is nil, Dial and TLSClientConfig are used. + // + // If DialTLS is set, the Dial hook is not used for HTTPS + // requests and the TLSClientConfig and TLSHandshakeTimeout + // are ignored. The returned net.Conn is assumed to already be + // past the TLS handshake. + DialTLS func(network, addr string) (net.Conn, error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + TLSHandshakeTimeout time.Duration + + // DisableKeepAlives, if true, prevents re-use of TCP connections + // between different HTTP requests. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) to keep per-host. If zero, + // DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + + // TODO: tunable on global max cached connections + // TODO: tunable on timeout on cached connections +} + +// ProxyFromEnvironment returns the URL of the proxy to use for a +// given request, as indicated by the environment variables +// HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions +// thereof). HTTPS_PROXY takes precedence over HTTP_PROXY for https +// requests. +// +// The environment values may be either a complete URL or a +// "host[:port]", in which case the "http" scheme is assumed. +// An error is returned if the value is a different form. +// +// A nil URL and nil error are returned if no proxy is defined in the +// environment, or a proxy should not be used for the given request, +// as defined by NO_PROXY. +// +// As a special case, if req.URL.Host is "localhost" (with or without +// a port number), then a nil URL and nil error will be returned. +func ProxyFromEnvironment(req *Request) (*url.URL, error) { + var proxy string + if req.URL.Scheme == "https" { + proxy = httpsProxyEnv.Get() + } + if proxy == "" { + proxy = httpProxyEnv.Get() + } + if proxy == "" { + return nil, nil + } + if !useProxy(canonicalAddr(req.URL)) { + return nil, nil + } + proxyURL, err := url.Parse(proxy) + if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") { + // proxy was bogus. Try prepending "http://" to it and + // see if that parses correctly. If not, we fall + // through and complain about the original one. + if proxyURL, err := url.Parse("http://" + proxy); err == nil { + return proxyURL, nil + } + } + if err != nil { + return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err) + } + return proxyURL, nil +} + +// ProxyURL returns a proxy function (for use in a Transport) +// that always returns the same URL. +func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { + return func(*Request) (*url.URL, error) { + return fixedURL, nil + } +} + +// transportRequest is a wrapper around a *Request that adds +// optional extra headers to write. +type transportRequest struct { + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil +} + +func (tr *transportRequest) extraHeaders() Header { + if tr.extra == nil { + tr.extra = make(Header) + } + return tr.extra +} + +// RoundTrip implements the RoundTripper interface. +// +// For higher-level HTTP client support (such as handling of cookies +// and redirects), see Get, Post, and the Client type. +func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { + if req.URL == nil { + req.closeBody() + return nil, errors.New("http: nil Request.URL") + } + if req.Header == nil { + req.closeBody() + return nil, errors.New("http: nil Request.Header") + } + if req.URL.Scheme != "http" && req.URL.Scheme != "https" { + t.altMu.RLock() + var rt RoundTripper + if t.altProto != nil { + rt = t.altProto[req.URL.Scheme] + } + t.altMu.RUnlock() + if rt == nil { + req.closeBody() + return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + } + return rt.RoundTrip(req) + } + if req.URL.Host == "" { + req.closeBody() + return nil, errors.New("http: no Host in request URL") + } + treq := &transportRequest{Request: req} + cm, err := t.connectMethodForRequest(treq) + if err != nil { + req.closeBody() + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(req, cm) + if err != nil { + t.setReqCanceler(req, nil) + req.closeBody() + return nil, err + } + + return pconn.roundTrip(treq) +} + +// RegisterProtocol registers a new protocol with scheme. +// The Transport will pass requests using the given scheme to rt. +// It is rt's responsibility to simulate HTTP request semantics. +// +// RegisterProtocol can be used by other packages to provide +// implementations of protocol schemes like "ftp" or "file". +func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { + if scheme == "http" || scheme == "https" { + panic("protocol " + scheme + " already registered") + } + t.altMu.Lock() + defer t.altMu.Unlock() + if t.altProto == nil { + t.altProto = make(map[string]RoundTripper) + } + if _, exists := t.altProto[scheme]; exists { + panic("protocol " + scheme + " already registered") + } + t.altProto[scheme] = rt +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.idleConnCh = nil + t.wantIdle = true + t.idleMu.Unlock() + for _, conns := range m { + for _, pconn := range conns { + pconn.close() + } + } +} + +// CancelRequest cancels an in-flight request by closing its +// connection. +func (t *Transport) CancelRequest(req *Request) { + t.reqMu.Lock() + cancel := t.reqCanceler[req] + t.reqMu.Unlock() + if cancel != nil { + cancel() + } +} + +// +// Private implementation past this point. +// + +var ( + httpProxyEnv = &envOnce{ + names: []string{"HTTP_PROXY", "http_proxy"}, + } + httpsProxyEnv = &envOnce{ + names: []string{"HTTPS_PROXY", "https_proxy"}, + } + noProxyEnv = &envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +type envOnce struct { + names []string + once sync.Once + val string +} + +func (e *envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// reset is used by tests +func (e *envOnce) reset() { + e.once = sync.Once{} + e.val = "" +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) + } + return cm, err +} + +// proxyAuth returns the Proxy-Authorization header to set +// on requests, if applicable. +func (cm *connectMethod) proxyAuth() string { + if cm.proxyURL == nil { + return "" + } + if u := cm.proxyURL.User; u != nil { + username := u.Username() + password, _ := u.Password() + return "Basic " + basicAuth(username, password) + } + return "" +} + +// putIdleConn adds pconn to the list of idle persistent connections awaiting +// a new request. +// If pconn is no longer needed or not in a good state, putIdleConn +// returns false. +func (t *Transport) putIdleConn(pconn *persistConn) bool { + if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { + pconn.close() + return false + } + if pconn.isBroken() { + return false + } + key := pconn.cacheKey + max := t.MaxIdleConnsPerHost + if max == 0 { + max = DefaultMaxIdleConnsPerHost + } + t.idleMu.Lock() + + waitingDialer := t.idleConnCh[key] + select { + case waitingDialer <- pconn: + // We're done with this pconn and somebody else is + // currently waiting for a conn of this type (they're + // actively dialing, but this conn is ready + // first). Chrome calls this socket late binding. See + // https://insouciant.org/tech/connection-management-in-chromium/ + t.idleMu.Unlock() + return true + default: + if waitingDialer != nil { + // They had populated this, but their dial won + // first, so we can clean up this map entry. + delete(t.idleConnCh, key) + } + } + if t.wantIdle { + t.idleMu.Unlock() + pconn.close() + return false + } + if t.idleConn == nil { + t.idleConn = make(map[connectMethodKey][]*persistConn) + } + if len(t.idleConn[key]) >= max { + t.idleMu.Unlock() + pconn.close() + return false + } + for _, exist := range t.idleConn[key] { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } + t.idleConn[key] = append(t.idleConn[key], pconn) + t.idleMu.Unlock() + return true +} + +// getIdleConnCh returns a channel to receive and return idle +// persistent connection for the given connectMethod. +// It may return nil, if persistent connections are not being used. +func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn { + if t.DisableKeepAlives { + return nil + } + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() + t.wantIdle = false + if t.idleConnCh == nil { + t.idleConnCh = make(map[connectMethodKey]chan *persistConn) + } + ch, ok := t.idleConnCh[key] + if !ok { + ch = make(chan *persistConn) + t.idleConnCh[key] = ch + } + return ch +} + +func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) { + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() + if t.idleConn == nil { + return nil + } + for { + pconns, ok := t.idleConn[key] + if !ok { + return nil + } + if len(pconns) == 1 { + pconn = pconns[0] + delete(t.idleConn, key) + } else { + // 2 or more cached connections; pop last + // TODO: queue? + pconn = pconns[len(pconns)-1] + t.idleConn[key] = pconns[:len(pconns)-1] + } + if !pconn.isBroken() { + return + } + } +} + +func (t *Transport) setReqCanceler(r *Request, fn func()) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[*Request]func()) + } + if fn != nil { + t.reqCanceler[r] = fn + } else { + delete(t.reqCanceler, r) + } +} + +func (t *Transport) dial(network, addr string) (c net.Conn, err error) { + if t.Dial != nil { + return t.Dial(network, addr) + } + return net.Dial(network, addr) +} + +// Testing hooks: +var prePendingDial, postPendingDial func() + +// getConn dials and creates a new persistConn to the target as +// specified in the connectMethod. This includes doing a proxy CONNECT +// and/or setting up TLS. If this doesn't return an error, the persistConn +// is ready to write requests to. +func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) { + if pc := t.getIdleConn(cm); pc != nil { + return pc, nil + } + + type dialRes struct { + pc *persistConn + err error + } + dialc := make(chan dialRes) + + handlePendingDial := func() { + if prePendingDial != nil { + prePendingDial() + } + go func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + if postPendingDial != nil { + postPendingDial() + } + }() + } + + cancelc := make(chan struct{}) + t.setReqCanceler(req, func() { close(cancelc) }) + + go func() { + pc, err := t.dialConn(cm) + dialc <- dialRes{pc, err} + }() + + idleConnCh := t.getIdleConnCh(cm) + select { + case v := <-dialc: + // Our dial finished. + return v.pc, v.err + case pc := <-idleConnCh: + // Another request finished first and its net.Conn + // became available before our dial. Or somebody + // else's dial that they didn't use. + // But our dial is still going, so give it away + // when it finishes: + handlePendingDial() + return pc, nil + case <-cancelc: + handlePendingDial() + return nil, errors.New("net/http: request canceled while waiting for connection") + } +} + +func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { + pconn := &persistConn{ + t: t, + cacheKey: cm.key(), + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + writeErrCh: make(chan error, 1), + } + tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil + if tlsDial { + var err error + pconn.conn, err = t.DialTLS("tcp", cm.addr()) + if err != nil { + return nil, err + } + if tc, ok := pconn.conn.(*tls.Conn); ok { + cs := tc.ConnectionState() + pconn.tlsState = &cs + } + } else { + conn, err := t.dial("tcp", cm.addr()) + if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } + return nil, err + } + pconn.conn = conn + } + + // Proxy setup. + switch { + case cm.proxyURL == nil: + // Do nothing. Not using a proxy. + case cm.targetScheme == "http": + pconn.isProxy = true + if pa := cm.proxyAuth(); pa != "" { + pconn.mutateHeaderFunc = func(h Header) { + h.Set("Proxy-Authorization", pa) + } + } + case cm.targetScheme == "https": + conn := pconn.conn + connectReq := &Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: cm.targetAddr}, + Host: cm.targetAddr, + Header: make(Header), + } + if pa := cm.proxyAuth(); pa != "" { + connectReq.Header.Set("Proxy-Authorization", pa) + } + connectReq.Write(conn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + conn.Close() + return nil, errors.New(f[1]) + } + } + + if cm.targetScheme == "https" && !tlsDial { + // Initiate TLS and check remote host name against certificate. + cfg := t.TLSClientConfig + if cfg == nil || cfg.ServerName == "" { + host := cm.tlsHost() + if cfg == nil { + cfg = &tls.Config{ServerName: host} + } else { + clone := *cfg // shallow clone + clone.ServerName = host + cfg = &clone + } + } + plainConn := pconn.conn + tlsConn := tls.Client(plainConn, cfg) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + err := tlsConn.Handshake() + if timer != nil { + timer.Stop() + } + errc <- err + }() + if err := <-errc; err != nil { + plainConn.Close() + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + plainConn.Close() + return nil, err + } + } + cs := tlsConn.ConnectionState() + pconn.tlsState = &cs + pconn.conn = tlsConn + } + + pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF}) + pconn.bw = bufio.NewWriter(pconn.conn) + go pconn.readLoop() + go pconn.writeLoop() + return pconn, nil +} + +// useProxy returns true if requests to addr should use a proxy, +// according to the NO_PROXY or no_proxy environment variable. +// addr is always a canonicalAddr with a host and port. +func useProxy(addr string) bool { + if len(addr) == 0 { + return true + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + if host == "localhost" { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() { + return false + } + } + + no_proxy := noProxyEnv.Get() + if no_proxy == "*" { + return false + } + + addr = strings.ToLower(strings.TrimSpace(addr)) + if hasPort(addr) { + addr = addr[:strings.LastIndex(addr, ":")] + } + + for _, p := range strings.Split(no_proxy, ",") { + p = strings.ToLower(strings.TrimSpace(p)) + if len(p) == 0 { + continue + } + if hasPort(p) { + p = p[:strings.LastIndex(p, ":")] + } + if addr == p { + return false + } + if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) { + // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com" + return false + } + if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' { + // no_proxy "foo.com" matches "bar.foo.com" + return false + } + } + return true +} + +// connectMethod is the map key (in its String form) for keeping persistent +// TCP connections alive for subsequent HTTP requests. +// +// A connect method may be of the following types: +// +// Cache key form Description +// ----------------- ------------------------- +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// +// Note: no support to https to the proxy yet. +// +type connectMethod struct { + proxyURL *url.URL // nil for no proxy, else full proxy URL + targetScheme string // "http" or "https" + targetAddr string // Not used if proxy + http targetScheme (4th example in table) +} + +func (cm *connectMethod) key() connectMethodKey { + proxyStr := "" + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if cm.targetScheme == "http" { + targetAddr = "" + } + } + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + } +} + +// addr returns the first hop "host:port" to which we need to TCP connect. +func (cm *connectMethod) addr() string { + if cm.proxyURL != nil { + return canonicalAddr(cm.proxyURL) + } + return cm.targetAddr +} + +// tlsHost returns the host name to match against the peer's +// TLS certificate. +func (cm *connectMethod) tlsHost() string { + h := cm.targetAddr + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + return h +} + +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string +} + +func (k connectMethodKey) String() string { + // Only used by tests. + return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr) +} + +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + t *Transport + cacheKey connectMethodKey + conn net.Conn + tlsState *tls.ConnectionState + br *bufio.Reader // from conn + sawEOF bool // whether we've seen EOF from conn; owned by readLoop + bw *bufio.Writer // to conn + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // closed when conn closed + isProxy bool + // writeErrCh passes the request write error (usually nil) + // from the writeLoop goroutine to the readLoop which passes + // it off to the res.Body reader, which then uses it to decide + // whether or not a connection can be reused. Issue 7569. + writeErrCh chan error + + lk sync.Mutex // guards following fields + numExpectedResponses int + closed bool // whether conn has been closed + broken bool // an error has happened on this connection; marked broken so it's not reused. + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) +} + +// isBroken reports whether this connection is in a known broken state. +func (pc *persistConn) isBroken() bool { + pc.lk.Lock() + b := pc.broken + pc.lk.Unlock() + return b +} + +func (pc *persistConn) cancelRequest() { + pc.conn.Close() +} + +var remoteSideClosedFunc func(error) bool // or nil to use default + +func remoteSideClosed(err error) bool { + if err == io.EOF { + return true + } + if remoteSideClosedFunc != nil { + return remoteSideClosedFunc(err) + } + return false +} + +func (pc *persistConn) readLoop() { + alive := true + + for alive { + pb, err := pc.br.Peek(1) + + pc.lk.Lock() + if pc.numExpectedResponses == 0 { + if !pc.closed { + pc.closeLocked() + if len(pb) > 0 { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", + string(pb), err) + } + } + pc.lk.Unlock() + return + } + pc.lk.Unlock() + + rc := <-pc.reqch + + var resp *Response + if err == nil { + resp, err = ReadResponse(pc.br, rc.req) + if err == nil && resp.StatusCode == 100 { + // Skip any 100-continue for now. + // TODO(bradfitz): if rc.req had "Expect: 100-continue", + // actually block the request body write and signal the + // writeLoop now to begin sending it. (Issue 2184) For now we + // eat it, since we're never expecting one. + resp, err = ReadResponse(pc.br, rc.req) + } + } + + if resp != nil { + resp.TLS = pc.tlsState + } + + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 + + if err != nil { + pc.close() + } else { + if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Body = &gzipReader{body: resp.Body} + } + resp.Body = &bodyEOFSignal{body: resp.Body} + } + + if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + alive = false + } + + var waitForBodyRead chan bool + if hasBody { + waitForBodyRead = make(chan bool, 2) + resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { + // Sending false here sets alive to + // false and closes the connection + // below. + waitForBodyRead <- false + return nil + } + resp.Body.(*bodyEOFSignal).fn = func(err error) { + waitForBodyRead <- alive && + err == nil && + !pc.sawEOF && + pc.wroteRequest() && + pc.t.putIdleConn(pc) + } + } + + if alive && !hasBody { + alive = !pc.sawEOF && + pc.wroteRequest() && + pc.t.putIdleConn(pc) + } + + rc.ch <- responseAndError{resp, err} + + // Wait for the just-returned response body to be fully consumed + // before we race and peek on the underlying bufio reader. + if waitForBodyRead != nil { + select { + case alive = <-waitForBodyRead: + case <-pc.closech: + alive = false + } + } + + pc.t.setReqCanceler(rc.req, nil) + + if !alive { + pc.close() + } + } +} + +func (pc *persistConn) writeLoop() { + for { + select { + case wr := <-pc.writech: + if pc.isBroken() { + wr.ch <- errors.New("http: can't write HTTP request on broken connection") + continue + } + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + if err == nil { + err = pc.bw.Flush() + } + if err != nil { + pc.markBroken() + wr.req.Request.closeBody() + } + pc.writeErrCh <- err // to the body reader, which might recycle us + wr.ch <- err // to the roundTrip function + case <-pc.closech: + return + } + } +} + +// wroteRequest is a check before recycling a connection that the previous write +// (from writeLoop above) happened and was successful. +func (pc *persistConn) wroteRequest() bool { + select { + case err := <-pc.writeErrCh: + // Common case: the write happened well before the response, so + // avoid creating a timer. + return err == nil + default: + // Rare case: the request was written in writeLoop above but + // before it could send to pc.writeErrCh, the reader read it + // all, processed it, and called us here. In this case, give the + // write goroutine a bit of time to finish its send. + // + // Less rare case: We also get here in the legitimate case of + // Issue 7569, where the writer is still writing (or stalled), + // but the server has already replied. In this case, we don't + // want to wait too long, and we want to return false so this + // connection isn't re-used. + select { + case err := <-pc.writeErrCh: + return err == nil + case <-time.After(50 * time.Millisecond): + return false + } + } +} + +type responseAndError struct { + res *Response + err error +} + +type requestAndChan struct { + req *Request + ch chan responseAndError + + // did the Transport (as opposed to the client code) add an + // Accept-Encoding gzip header? only if it we set it do + // we transparently decode the gzip. + addedGzip bool +} + +// A writeRequest is sent by the readLoop's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} +var errClosed error = &httpError{err: "net/http: transport closed before response was received"} + +func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { + pc.t.setReqCanceler(req.Request, pc.cancelRequest) + pc.lk.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.lk.Unlock() + + if headerFn != nil { + headerFn(req.extraHeaders()) + } + + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + req.Method != "HEAD" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // http://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See http://golang.org/issue/8923 + requestedGzip = true + req.extraHeaders().Set("Accept-Encoding", "gzip") + } + + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} + + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + + var re responseAndError + var pconnDeadCh = pc.closech + var failTicker <-chan time.Time + var respHeaderTimer <-chan time.Time +WaitResponse: + for { + select { + case err := <-writeErrCh: + if err != nil { + re = responseAndError{nil, err} + pc.close() + break WaitResponse + } + if d := pc.t.ResponseHeaderTimeout; d > 0 { + respHeaderTimer = time.After(d) + } + case <-pconnDeadCh: + // The persist connection is dead. This shouldn't + // usually happen (only with Connection: close responses + // with no response bodies), but if it does happen it + // means either a) the remote server hung up on us + // prematurely, or b) the readLoop sent us a response & + // closed its closech at roughly the same time, and we + // selected this case first, in which case a response + // might still be coming soon. + // + // We can't avoid the select race in b) by using a unbuffered + // resc channel instead, because then goroutines can + // leak if we exit due to other errors. + pconnDeadCh = nil // avoid spinning + failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc + case <-failTicker: + re = responseAndError{err: errClosed} + break WaitResponse + case <-respHeaderTimer: + pc.close() + re = responseAndError{err: errTimeout} + break WaitResponse + case re = <-resc: + break WaitResponse + } + } + + pc.lk.Lock() + pc.numExpectedResponses-- + pc.lk.Unlock() + + if re.err != nil { + pc.t.setReqCanceler(req.Request, nil) + } + return re.res, re.err +} + +// markBroken marks a connection as broken (so it's not reused). +// It differs from close in that it doesn't close the underlying +// connection for use when it's still being read. +func (pc *persistConn) markBroken() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.broken = true +} + +func (pc *persistConn) close() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.closeLocked() +} + +func (pc *persistConn) closeLocked() { + pc.broken = true + if !pc.closed { + pc.conn.Close() + pc.closed = true + close(pc.closech) + } + pc.mutateHeaderFunc = nil +} + +var portMap = map[string]string{ + "http": "80", + "https": "443", +} + +// canonicalAddr returns url.Host but always with a ":port" suffix +func canonicalAddr(url *url.URL) string { + addr := url.Host + if !hasPort(addr) { + return addr + ":" + portMap[url.Scheme] + } + return addr +} + +// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. If earlyCloseFn is non-nil and Close is called before +// io.EOF is seen, earlyCloseFn is called instead of fn, and its +// return value is the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errors.New("http: read on closed response body") + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + es.condfn(err) + return err +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) { + if es.fn == nil { + return + } + if err == io.EOF { + err = nil + } + es.fn(err) + es.fn = nil +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr io.Reader // lazily-initialized gzip reader +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + +type readerAndCloser struct { + io.Reader + io.Closer +} + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + +type noteEOFReader struct { + r io.Reader + sawEOF *bool +} + +func (nr noteEOFReader) Read(p []byte) (n int, err error) { + n, err = nr.r.Read(p) + if err == io.EOF { + *nr.sawEOF = true + } + return +} diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go new file mode 100644 index 000000000..defa63370 --- /dev/null +++ b/src/net/http/transport_test.go @@ -0,0 +1,2324 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for transport.go + +package http_test + +import ( + "bufio" + "bytes" + "compress/gzip" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + . "net/http" + "net/http/httptest" + "net/url" + "os" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close +// and then verify that the final 2 responses get errors back. + +// hostPortHandler writes back the client's "host:port". +var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + if r.FormValue("close") == "true" { + w.Header().Set("Connection", "close") + } + w.Write([]byte(r.RemoteAddr)) +}) + +// testCloseConn is a net.Conn tracked by a testConnSet. +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (c *testCloseConn) Close() error { + c.set.remove(c) + return c.Conn.Close() +} + +// testConnSet tracks a set of TCP connections and whether they've +// been closed. +type testConnSet struct { + t *testing.T + mu sync.Mutex // guards closed and list + closed map[net.Conn]bool + list []net.Conn // in order created +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + tcs.closed[c] = false + tcs.list = append(tcs.list, c) +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + tcs.closed[c] = true +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + t: t, + closed: make(map[net.Conn]bool), + } + dial := func(n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) check(t *testing.T) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + for i := 4; i >= 0; i-- { + for i, c := range tcs.list { + if tcs.closed[c] { + continue + } + if i != 0 { + tcs.mu.Unlock() + time.Sleep(50 * time.Millisecond) + tcs.mu.Lock() + continue + } + t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) + } + } +} + +// Two subsequent requests and verify their response is the same. +// The response from the server is our own IP:port +func TestTransportKeepAlives(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, disableKeepAlive := range []bool{false, true} { + tr := &Transport{DisableKeepAlives: disableKeepAlive} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + fetch := func(n int) string { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + bodiesDiffer := body1 != body2 + if bodiesDiffer != disableKeepAlive { + t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + disableKeepAlive, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportConnectionCloseOnResponse(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + connSet, testDial := makeTestDial(t) + + for _, connectionClose := range []bool{false, true} { + tr := &Transport{ + Dial: testDial, + } + c := &Client{Transport: tr} + + fetch := func(n int) string { + req := new(Request) + var err error + req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + + tr.CloseIdleConnections() + } + + connSet.check(t) +} + +func TestTransportConnectionCloseOnRequest(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + connSet, testDial := makeTestDial(t) + + for _, connectionClose := range []bool{false, true} { + tr := &Transport{ + Dial: testDial, + } + c := &Client{Transport: tr} + + fetch := func(n int) string { + req := new(Request) + var err error + req.URL, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Close = connectionClose + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + + tr.CloseIdleConnections() + } + + connSet.check(t) +} + +func TestTransportIdleCacheKeys(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } + + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + ioutil.ReadAll(resp.Body) + + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) + } + + if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { + t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) + } + + tr.CloseIdleConnections() + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } +} + +// Tests that the HTTP transport re-uses connections when a client +// reads to the end of a response Body without closing it. +func TestTransportReadToEndReusesConn(t *testing.T) { + defer afterTest(t) + const msg = "foobar" + + var addrSeen map[string]int + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + addrSeen[r.RemoteAddr]++ + if r.URL.Path == "/chunked/" { + w.WriteHeader(200) + w.(http.Flusher).Flush() + } else { + w.Header().Set("Content-Type", strconv.Itoa(len(msg))) + w.WriteHeader(200) + } + w.Write([]byte(msg)) + })) + defer ts.Close() + + buf := make([]byte, len(msg)) + + for pi, path := range []string{"/content-length/", "/chunked/"} { + wantLen := []int{len(msg), -1}[pi] + addrSeen = make(map[string]int) + for i := 0; i < 3; i++ { + res, err := http.Get(ts.URL + path) + if err != nil { + t.Errorf("Get %s: %v", path, err) + continue + } + // We want to close this body eventually (before the + // defer afterTest at top runs), but not before the + // len(addrSeen) check at the bottom of this test, + // since Closing this early in the loop would risk + // making connections be re-used for the wrong reason. + defer res.Body.Close() + + if res.ContentLength != int64(wantLen) { + t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) + } + n, err := res.Body.Read(buf) + if n != len(msg) || err != io.EOF { + t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + } + } + if len(addrSeen) != 1 { + t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) + } + } +} + +func TestTransportMaxPerHostIdleConns(t *testing.T) { + defer afterTest(t) + resch := make(chan string) + gotReq := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + gotReq <- true + msg := <-resch + _, err := w.Write([]byte(msg)) + if err != nil { + t.Fatalf("Write: %v", err) + } + })) + defer ts.Close() + maxIdleConns := 2 + tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} + c := &Client{Transport: tr} + + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we write to resch, though. + donech := make(chan bool) + doReq := func() { + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if _, err := ioutil.ReadAll(resp.Body); err != nil { + t.Errorf("ReadAll: %v", err) + return + } + donech <- true + } + go doReq() + <-gotReq + go doReq() + <-gotReq + go doReq() + <-gotReq + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) + } + + resch <- "res1" + <-donech + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) + } + cacheKey := "|http|" + ts.Listener.Addr().String() + if keys[0] != cacheKey { + t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) + } + if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after first response, expected %d idle conns; got %d", e, g) + } + + resch <- "res2" + <-donech + if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after second response, expected %d idle conns; got %d", e, g) + } + + resch <- "res3" + <-donech + if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after third response, still expected %d idle conns; got %d", e, g) + } +} + +func TestTransportServerClosingUnexpectedly(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...interface{}) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(time.Second / time.Duration(retries)) + } + for retries >= 0 { + retries-- + res, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) + } + panic("unreachable") + } + + body1 := fetch(1, 0) + body2 := fetch(2, 0) + + ts.CloseClientConnections() // surprise! + + // This test has an expected race. Sleeping for 25 ms prevents + // it on most fast machines, causing the next fetch() call to + // succeed quickly. But if we do get errors, fetch() will retry 5 + // times with some delays between. + time.Sleep(25 * time.Millisecond) + + body3 := fetch(3, 5) + + if body1 != body2 { + t.Errorf("expected body1 and body2 to be equal") + } + if body2 == body3 { + t.Errorf("expected body2 and body3 to be different") + } +} + +// Test for http://golang.org/issue/2616 (appropriate issue number) +// This fails pretty reliably with GOMAXPROCS=100 or something high. +func TestStressSurpriseServerCloses(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in short mode") + } + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "5") + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Hello")) + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + buf.Flush() + conn.Close() + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + // Do a bunch of traffic from different goroutines. Send to activityc + // after each request completes, regardless of whether it failed. + const ( + numClients = 50 + reqsPerClient = 250 + ) + activityc := make(chan bool) + for i := 0; i < numClients; i++ { + go func() { + for i := 0; i < reqsPerClient; i++ { + res, err := c.Get(ts.URL) + if err == nil { + // We expect errors since the server is + // hanging up on us after telling us to + // send more requests, so we don't + // actually care what the error is. + // But we want to close the body in cases + // where we won the race. + res.Body.Close() + } + activityc <- true + } + }() + } + + // Make sure all the request come back, one way or another. + for i := 0; i < numClients*reqsPerClient; i++ { + select { + case <-activityc: + case <-time.After(5 * time.Second): + t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") + } + } +} + +// TestTransportHeadResponses verifies that we deal with Content-Lengths +// with no bodies properly +func TestTransportHeadResponses(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Content-Length", "123") + w.WriteHeader(200) + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + for i := 0; i < 2; i++ { + res, err := c.Head(ts.URL) + if err != nil { + t.Errorf("error on loop %d: %v", i, err) + continue + } + if e, g := "123", res.Header.Get("Content-Length"); e != g { + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) + } + if e, g := int64(123), res.ContentLength; e != g { + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) + } + if all, err := ioutil.ReadAll(res.Body); err != nil { + t.Errorf("loop %d: Body ReadAll: %v", i, err) + } else if len(all) != 0 { + t.Errorf("Bogus body %q", all) + } + } +} + +// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding +// on responses to HEAD requests. +func TestTransportHeadChunkedResponse(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Transfer-Encoding", "chunked") // client should ignore + w.Header().Set("x-client-ipport", r.RemoteAddr) + w.WriteHeader(200) + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + res1, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("request 1 error: %v", err) + } + res2, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("request 2 error: %v", err) + } + if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { + t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) + } +} + +var roundTripTests = []struct { + accept string + expectAccept string + compressed bool +}{ + // Requests with no accept-encoding header use transparent compression + {"", "gzip", false}, + // Requests with other accept-encoding should pass through unmodified + {"foo", "foo", false}, + // Requests with accept-encoding == gzip should be passed through + {"gzip", "gzip", true}, +} + +// Test that the modification made to the Request by the RoundTripper is cleaned up +func TestRoundTripGzip(t *testing.T) { + defer afterTest(t) + const responseBody = "test response body" + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + accept := req.Header.Get("Accept-Encoding") + if expect := req.FormValue("expect_accept"); accept != expect { + t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", + req.FormValue("testnum"), accept, expect) + } + if accept == "gzip" { + rw.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(rw) + gz.Write([]byte(responseBody)) + gz.Close() + } else { + rw.Header().Set("Content-Encoding", accept) + rw.Write([]byte(responseBody)) + } + })) + defer ts.Close() + + for i, test := range roundTripTests { + // Test basic request (no accept-encoding) + req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) + if test.accept != "" { + req.Header.Set("Accept-Encoding", test.accept) + } + res, err := DefaultTransport.RoundTrip(req) + var body []byte + if test.compressed { + var r *gzip.Reader + r, err = gzip.NewReader(res.Body) + if err != nil { + t.Errorf("%d. gzip NewReader: %v", i, err) + continue + } + body, err = ioutil.ReadAll(r) + res.Body.Close() + } else { + body, err = ioutil.ReadAll(res.Body) + } + if err != nil { + t.Errorf("%d. Error: %q", i, err) + continue + } + if g, e := string(body), responseBody; g != e { + t.Errorf("%d. body = %q; want %q", i, g, e) + } + if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { + t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) + } + if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { + t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) + } + } + +} + +func TestTransportGzip(t *testing.T) { + defer afterTest(t) + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.Method == "HEAD" { + if g := req.Header.Get("Accept-Encoding"); g != "" { + t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) + } + return + } + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } + gz := gzip.NewWriter(w) + gz.Write([]byte(testString)) + if req.FormValue("body") == "large" { + io.CopyN(gz, rand.Reader, nRandBytes) + } + gz.Close() + })) + defer ts.Close() + + for _, chunked := range []string{"1", "0"} { + c := &Client{Transport: &Transport{}} + + // First fetch something large, but only read some of it. + res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, err = c.Get(ts.URL + "/?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } + } + + // And a HEAD request too, because they're always weird. + c := &Client{Transport: &Transport{}} + res, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("Head: %v", err) + } + if res.StatusCode != 200 { + t.Errorf("Head status=%d; want=200", res.StatusCode) + } +} + +func TestTransportProxy(t *testing.T) { + defer afterTest(t) + ch := make(chan string, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "real server" + })) + defer ts.Close() + proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "proxy for " + r.URL.String() + })) + defer proxy.Close() + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} + c.Head(ts.URL) + got := <-ch + want := "proxy for " + ts.URL + "/" + if got != want { + t.Errorf("want %q, got %q", want, got) + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })) + defer ts.Close() + + c := &Client{Transport: &Transport{}} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// golang.org/issue/7750: request fails when server replies with +// a short gzip body +func TestTransportGzipShort(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write([]byte{0x1f, 0x8b}) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err == nil { + t.Fatal("Expect an error from reading a body.") + } + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) + } +} + +// tests that persistent goroutine connections shut down when no longer desired. +func TestTransportPersistConnLeak(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + gotReqCh := make(chan bool) + unblockCh := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + gotReqCh <- true + <-unblockCh + w.Header().Set("Content-Length", "0") + w.WriteHeader(204) + })) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + n0 := runtime.NumGoroutine() + + const numReq = 25 + didReqCh := make(chan bool) + for i := 0; i < numReq; i++ { + go func() { + res, err := c.Get(ts.URL) + didReqCh <- true + if err != nil { + t.Errorf("client fetch error: %v", err) + return + } + res.Body.Close() + }() + } + + // Wait for all goroutines to be stuck in the Handler. + for i := 0; i < numReq; i++ { + <-gotReqCh + } + + nhigh := runtime.NumGoroutine() + + // Tell all handlers to unblock and reply. + for i := 0; i < numReq; i++ { + unblockCh <- true + } + + // Wait for all HTTP clients to be done. + for i := 0; i < numReq; i++ { + <-didReqCh + } + + tr.CloseIdleConnections() + time.Sleep(100 * time.Millisecond) + runtime.GC() + runtime.GC() // even more. + nfinal := runtime.NumGoroutine() + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + if int(growth) > 5 { + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + t.Error("too many new goroutines") + } +} + +// golang.org/issue/4531: Transport leaks goroutines when +// request.ContentLength is explicitly short +func TestTransportPersistConnLeakShortBody(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + n0 := runtime.NumGoroutine() + body := []byte("Hello") + for i := 0; i < 20; i++ { + req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = int64(len(body) - 2) // explicitly short + _, err = c.Do(req) + if err == nil { + t.Fatal("Expect an error from writing too long of a body.") + } + } + nhigh := runtime.NumGoroutine() + tr.CloseIdleConnections() + time.Sleep(400 * time.Millisecond) + runtime.GC() + nfinal := runtime.NumGoroutine() + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + if int(growth) > 5 { + t.Error("too many new goroutines") + } +} + +// This used to crash; http://golang.org/issue/3266 +func TestTransportIdleConnCrash(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &Client{Transport: tr} + + unblockCh := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + <-unblockCh + tr.CloseIdleConnections() + })) + defer ts.Close() + + didreq := make(chan bool) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } else { + res.Body.Close() // returns idle conn + } + didreq <- true + }() + unblockCh <- true + <-didreq +} + +// Test that the transport doesn't close the TCP connection early, +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made +// the old race easier to trigger. +func TestIssue3644(t *testing.T) { + defer afterTest(t) + const numFoos = 5000 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + for i := 0; i < numFoos; i++ { + w.Write([]byte("foo ")) + } + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != numFoos*len("foo ") { + t.Errorf("unexpected response length") + } +} + +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { + defer afterTest(t) + const deniedMsg = "sorry, denied." + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + +// From http://golang.org/issue/4454 , +// "client fails to handle requests with no body and chunked encoding" +func TestChunkedNoContent(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })) + defer ts.Close() + + for _, closeBody := range []bool{true, false} { + c := &Client{Transport: &Transport{}} + const n = 4 + for i := 1; i <= n; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) + } else { + if closeBody { + res.Body.Close() + } + } + } + } +} + +func TestTransportConcurrency(t *testing.T) { + defer afterTest(t) + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.FormValue("echo")) + })) + defer ts.Close() + + var wg sync.WaitGroup + wg.Add(numReqs) + + // Due to the Transport's "socket late binding" (see + // idleConnCh in transport.go), the numReqs HTTP requests + // below can finish with a dial still outstanding. To keep + // the leak checker happy, keep track of pending dials and + // wait for them to finish (and be closed or returned to the + // idle pool) before we close idle connections. + SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) + defer SetPendingDialHooks(nil, nil) + + tr := &Transport{} + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + reqs := make(chan string) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for req := range reqs { + res, err := c.Get(ts.URL + "/?echo=" + req) + if err != nil { + t.Errorf("error on req %s: %v", req, err) + wg.Done() + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("read error on req %s: %v", req, err) + wg.Done() + continue + } + if string(all) != req { + t.Errorf("body of req %s = %q; want %q", req, all, req) + } + res.Body.Close() + wg.Done() + } + }() + } + for i := 0; i < numReqs; i++ { + reqs <- fmt.Sprintf("request-%d", i) + } + wg.Wait() +} + +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + client := &Client{ + Transport: &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + }, + DisableKeepAlives: true, + }, + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := client.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + _, err = io.Copy(ioutil.Discard, sres.Body) + if err == nil { + t.Errorf("Unexpected successful copy") + break + } + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7237") + } + defer afterTest(t) + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { + defer r.Body.Close() + io.Copy(ioutil.Discard, r.Body) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + client := &Client{ + Transport: &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + }, + DisableKeepAlives: true, + }, + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := client.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) + _, err = client.Do(req) + if err == nil { + sres.Body.Close() + t.Errorf("Unexpected successful PUT") + break + } + sres.Body.Close() + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestTransportResponseHeaderTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + inHandler := make(chan bool, 1) + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { + inHandler <- true + }) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + inHandler <- true + time.Sleep(2 * time.Second) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + tr := &Transport{ + ResponseHeaderTimeout: 500 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + res, err := c.Get(ts.URL + tt.path) + select { + case <-inHandler: + case <-time.After(5 * time.Second): + t.Errorf("never entered handler for test index %d, %s", i, tt.path) + continue + } + if err != nil { + uerr, ok := err.(*url.Error) + if !ok { + t.Errorf("error is not an url.Error; got: %#v", err) + continue + } + nerr, ok := uerr.Err.(net.Error) + if !ok { + t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + continue + } + if !nerr.Timeout() { + t.Errorf("want timeout error; got: %q", nerr) + continue + } + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func TestTransportCancelRequest(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := ioutil.ReadAll(res.Body) + d := time.Since(t0) + + if err == nil { + t.Error("expected an error reading the body") + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 3; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + +func TestTransportCancelRequestInDial(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + var logbuf bytes.Buffer + eventLog := log.New(&logbuf, "", 0) + + unblockDial := make(chan bool) + defer close(unblockDial) + + inDial := make(chan bool) + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + eventLog.Println("dial: blocking") + inDial <- true + <-unblockDial + return nil, errors.New("nope") + }, + } + cl := &Client{Transport: tr} + gotres := make(chan bool) + req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + go func() { + _, err := cl.Do(req) + eventLog.Printf("Get = %v", err) + gotres <- true + }() + + select { + case <-inDial: + case <-time.After(5 * time.Second): + t.Fatal("timeout; never saw blocking dial") + } + + eventLog.Printf("canceling") + tr.CancelRequest(req) + + select { + case <-gotres: + case <-time.After(5 * time.Second): + panic("hang. events are: " + logbuf.String()) + } + + got := logbuf.String() + want := `dial: blocking +canceling +Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection +` + if got != want { + t.Errorf("Got events:\n%s\nWant:\n%s", got, want) + } +} + +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { + defer afterTest(t) + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(Flusher).Flush() + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Fatalf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + +type fooProto struct{} + +func (fooProto) RoundTrip(req *Request) (*Response, error) { + res := &Response{ + Status: "200 OK", + StatusCode: 200, + Header: make(Header), + Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + } + return res, nil +} + +func TestTransportAltProto(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &Client{Transport: tr} + tr.RegisterProtocol("foo", fooProto{}) + res, err := c.Get("foo://bar.com/path") + if err != nil { + t.Fatal(err) + } + bodyb, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + body := string(bodyb) + if e := "You wanted foo://bar.com/path"; body != e { + t.Errorf("got response %q, want %q", body, e) + } +} + +func TestTransportNoHost(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + _, err := tr.RoundTrip(&Request{ + Header: make(Header), + URL: &url.URL{ + Scheme: "http", + }, + }) + want := "http: no Host in request URL" + if got := fmt.Sprint(err); got != want { + t.Errorf("error = %v; want %q", err, want) + } +} + +func TestTransportSocketLateBinding(t *testing.T) { + defer afterTest(t) + + mux := NewServeMux() + fooGate := make(chan bool, 1) + mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { + w.Header().Set("foo-ipport", r.RemoteAddr) + w.(Flusher).Flush() + <-fooGate + }) + mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { + w.Header().Set("bar-ipport", r.RemoteAddr) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + dialGate := make(chan bool, 1) + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + } + + dialGate <- true // only allow one dial + fooRes, err := c.Get(ts.URL + "/foo") + if err != nil { + t.Fatal(err) + } + fooAddr := fooRes.Header.Get("foo-ipport") + if fooAddr == "" { + t.Fatal("No addr on /foo request") + } + time.AfterFunc(200*time.Millisecond, func() { + // let the foo response finish so we can use its + // connection for /bar + fooGate <- true + io.Copy(ioutil.Discard, fooRes.Body) + fooRes.Body.Close() + }) + + barRes, err := c.Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + barAddr := barRes.Header.Get("bar-ipport") + if barAddr != fooAddr { + t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) + } + barRes.Body.Close() + dialGate <- false +} + +// Issue 2184 +func TestTransportReading100Continue(t *testing.T) { + defer afterTest(t) + + const numReqs = 5 + reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } + reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } + + send100Response := func(w *io.PipeWriter, r *io.PipeReader) { + defer w.Close() + defer r.Close() + br := bufio.NewReader(r) + n := 0 + for { + n++ + req, err := ReadRequest(br) + if err == io.EOF { + return + } + if err != nil { + t.Error(err) + return + } + slurp, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Server request body slurp: %v", err) + return + } + id := req.Header.Get("Request-Id") + resCode := req.Header.Get("X-Want-Response-Code") + if resCode == "" { + resCode = "100 Continue" + if string(slurp) != reqBody(n) { + t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) + } + } + body := fmt.Sprintf("Response number %d", n) + v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s +Date: Thu, 28 Feb 2013 17:55:41 GMT + +HTTP/1.1 200 OK +Content-Type: text/html +Echo-Request-Id: %s +Content-Length: %d + +%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) + w.Write(v) + if id == reqID(numReqs) { + return + } + } + + } + + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + testResponse := func(req *Request, name string, wantCode int) { + res, err := c.Do(req) + if err != nil { + t.Fatalf("%s: Do: %v", name, err) + } + if res.StatusCode != wantCode { + t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) + } + if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { + t.Errorf("%s: response id %q != request id %q", name, idBack, id) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: Slurp error: %v", name, err) + } + } + + // Few 100 responses, making sure we're not off-by-one. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("Request-Id", reqID(i)) + testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) + } + + // And some other informational 1xx but non-100 responses, to test + // we return them but don't re-use the connection. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("X-Want-Response-Code", "123 Sesame Street") + testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123) + } +} + +type proxyFromEnvTest struct { + req string // URL to fetch; blank means "http://example.com" + + env string // HTTP_PROXY + httpsenv string // HTTPS_PROXY + noenv string // NO_RPXY + + want string + wanterr error +} + +func (t proxyFromEnvTest) String() string { + var buf bytes.Buffer + space := func() { + if buf.Len() > 0 { + buf.WriteByte(' ') + } + } + if t.env != "" { + fmt.Fprintf(&buf, "http_proxy=%q", t.env) + } + if t.httpsenv != "" { + space() + fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) + } + if t.noenv != "" { + space() + fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) + } + req := "http://example.com" + if t.req != "" { + req = t.req + } + space() + fmt.Fprintf(&buf, "req=%q", req) + return strings.TrimSpace(buf.String()) +} + +var proxyFromEnvTests = []proxyFromEnvTest{ + {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, + {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, + {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, + {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, + + // Don't use secure for http + {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, + // Use secure for https. + {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, + {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, + + {want: "<nil>"}, + + {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, +} + +func TestProxyFromEnvironment(t *testing.T) { + ResetProxyEnv() + for _, tt := range proxyFromEnvTests { + os.Setenv("HTTP_PROXY", tt.env) + os.Setenv("HTTPS_PROXY", tt.httpsenv) + os.Setenv("NO_PROXY", tt.noenv) + ResetCachedEnvironment() + reqURL := tt.req + if reqURL == "" { + reqURL = "http://example.com" + } + req, _ := NewRequest("GET", reqURL, nil) + url, err := ProxyFromEnvironment(req) + if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { + t.Errorf("%v: got error = %q, want %q", tt, g, e) + continue + } + if got := fmt.Sprintf("%s", url); got != tt.want { + t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) + } + } +} + +func TestIdleConnChannelLeak(t *testing.T) { + var mu sync.Mutex + var n int + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + n++ + mu.Unlock() + })) + defer ts.Close() + + tr := &Transport{ + Dial: func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + }, + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + + // First, without keep-alives. + for _, disableKeep := range []bool{true, false} { + tr.DisableKeepAlives = disableKeep + for i := 0; i < 5; i++ { + _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) + if err != nil { + t.Fatal(err) + } + } + if got := tr.IdleConnChMapSizeForTesting(); got != 0 { + t.Fatalf("ForDisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) + } + } +} + +// Verify the status quo: that the Client.Post function coerces its +// body into a ReadCloser if it's a Closer, and that the Transport +// then closes it. +func TestTransportClosesRequestBody(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + cl := &Client{Transport: tr} + + closes := 0 + + res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } +} + +func TestTransportTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping in short mode") + } + ln := newLocalListener(t) + defer ln.Close() + testdonec := make(chan struct{}) + defer close(testdonec) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + <-testdonec + c.Close() + }() + + getdonec := make(chan struct{}) + go func() { + defer close(getdonec) + tr := &Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Error("expected error") + return + } + ue, ok := err.(*url.Error) + if !ok { + t.Errorf("expected url.Error; got %#v", err) + return + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Errorf("expected net.Error; got %#v", err) + return + } + if !ne.Timeout() { + t.Errorf("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Errorf("expected 'handshake timeout' in error; got %v", err) + } + }() + select { + case <-getdonec: + case <-time.After(5 * time.Second): + t.Error("test timeout; TLS handshake hung?") + } +} + +// Trying to repro golang.org/issue/3514 +func TestTLSServerClosesConnection(t *testing.T) { + defer afterTest(t) + if runtime.GOOS == "windows" { + t.Skip("skipping flaky test on Windows; golang.org/issue/7634") + } + closedc := make(chan bool, 1) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if strings.Contains(r.URL.Path, "/keep-alive-then-die") { + conn, _, _ := w.(Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) + conn.Close() + closedc <- true + return + } + fmt.Fprintf(w, "hello") + })) + defer ts.Close() + tr := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + + var nSuccess = 0 + var errs []error + const trials = 20 + for i := 0; i < trials; i++ { + tr.CloseIdleConnections() + res, err := client.Get(ts.URL + "/keep-alive-then-die") + if err != nil { + t.Fatal(err) + } + <-closedc + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "foo" { + t.Errorf("Got %q, want foo", slurp) + } + + // Now try again and see if we successfully + // pick a new connection. + res, err = client.Get(ts.URL + "/") + if err != nil { + errs = append(errs, err) + continue + } + slurp, err = ioutil.ReadAll(res.Body) + if err != nil { + errs = append(errs, err) + continue + } + nSuccess++ + } + if nSuccess > 0 { + t.Logf("successes = %d of %d", nSuccess, trials) + } else { + t.Errorf("All runs failed:") + } + for _, err := range errs { + t.Logf(" err: %v", err) + } +} + +// byteFromChanReader is an io.Reader that reads a single byte at a +// time from the channel. When the channel is closed, the reader +// returns io.EOF. +type byteFromChanReader chan byte + +func (c byteFromChanReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + b, ok := <-c + if !ok { + return 0, io.EOF + } + p[0] = b + return 1, nil +} + +// Verifies that the Transport doesn't reuse a connection in the case +// where the server replies before the request has been fully +// written. We still honor that reply (see TestIssue3595), but don't +// send future requests on the connection because it's then in a +// questionable state. +// golang.org/issue/7569 +func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + defer afterTest(t) + var sconn struct { + sync.Mutex + c net.Conn + } + var getOkay bool + closeConn := func() { + sconn.Lock() + defer sconn.Unlock() + if sconn.c != nil { + sconn.c.Close() + sconn.c = nil + if !getOkay { + t.Logf("Closed server connection") + } + } + } + defer closeConn() + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method == "GET" { + io.WriteString(w, "bar") + return + } + conn, _, _ := w.(Hijacker).Hijack() + sconn.Lock() + sconn.c = conn + sconn.Unlock() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive + go io.Copy(ioutil.Discard, conn) + })) + defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + + const bodySize = 256 << 10 + finalBit := make(byteFromChanReader, 1) + req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) + req.ContentLength = bodySize + res, err := client.Do(req) + if err := wantBody(res, err, "foo"); err != nil { + t.Errorf("POST response: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + res, err = client.Get(ts.URL) + if err := wantBody(res, err, "bar"); err != nil { + t.Errorf("GET response: %v", err) + return + } + getOkay = true // suppress test noise + }() + time.AfterFunc(5*time.Second, closeConn) + select { + case <-donec: + finalBit <- 'x' // unblock the writeloop of the first Post + close(finalBit) + case <-time.After(7 * time.Second): + t.Fatal("timeout waiting for GET request to finish") + } +} + +type errorReader struct { + err error +} + +func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } + +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + +// Issue 6981 +func TestTransportClosesBodyOnError(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see http://golang.org/issue/7782") + } + defer afterTest(t) + readBody := make(chan error, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := ioutil.ReadAll(r.Body) + readBody <- err + })) + defer ts.Close() + fakeErr := errors.New("fake error") + didClose := make(chan bool, 1) + req, _ := NewRequest("POST", ts.URL, struct { + io.Reader + io.Closer + }{ + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), + closerFunc(func() error { + select { + case didClose <- true: + default: + } + return nil + }), + }) + res, err := DefaultClient.Do(req) + if res != nil { + defer res.Body.Close() + } + if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { + t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) + } + select { + case err := <-readBody: + if err == nil { + t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") + } + case <-time.After(5 * time.Second): + t.Error("timeout waiting for server handler to complete") + } + select { + case <-didClose: + default: + t.Errorf("didn't see Body.Close") + } +} + +func TestTransportDialTLS(t *testing.T) { + var mu sync.Mutex // guards following + var gotReq, didDial bool + + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })) + defer ts.Close() + tr := &Transport{ + DialTLS: func(netw, addr string) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + c, err := tls.Dial(netw, addr, &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + return nil, err + } + return c, c.Handshake() + }, + } + defer tr.CloseIdleConnections() + client := &Client{Transport: tr} + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if !didDial { + t.Error("didn't use dial hook") + } +} + +// Test for issue 8755 +// Ensure that if a proxy returns an error, it is exposed by RoundTrip +func TestRoundTripReturnsProxyError(t *testing.T) { + badProxy := func(*http.Request) (*url.URL, error) { + return nil, errors.New("errorMessage") + } + + tr := &Transport{Proxy: badProxy} + + req, _ := http.NewRequest("GET", "http://example.com", nil) + + _, err := tr.RoundTrip(req) + + if err == nil { + t.Error("Expected proxy error to be returned by RoundTrip") + } +} + +// tests that putting an idle conn after a call to CloseIdleConns does return it +func TestTransportCloseIdleConnsThenReturn(t *testing.T) { + tr := &Transport{} + wantIdle := func(when string, n int) bool { + got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn + if got == n { + return true + } + t.Errorf("%s: idle conns = %d; want %d", when, got, n) + return false + } + wantIdle("start", 0) + if !tr.PutIdleTestConn() { + t.Fatal("put failed") + } + if !tr.PutIdleTestConn() { + t.Fatal("second put failed") + } + wantIdle("after put", 2) + tr.CloseIdleConnections() + if !tr.IsIdleForTesting() { + t.Error("should be idle after CloseIdleConnections") + } + wantIdle("after close idle", 0) + if tr.PutIdleTestConn() { + t.Fatal("put didn't fail") + } + wantIdle("after second put", 0) + + tr.RequestIdleConnChForTesting() // should toggle the transport out of idle mode + if tr.IsIdleForTesting() { + t.Error("shouldn't be idle after RequestIdleConnChForTesting") + } + if !tr.PutIdleTestConn() { + t.Fatal("after re-activation") + } + wantIdle("after final put", 1) +} + +// This tests that an client requesting a content range won't also +// implicitly ask for gzip support. If they want that, they need to do it +// on their own. +// golang.org/issue/8923 +func TestTransportRangeAndGzip(t *testing.T) { + defer afterTest(t) + reqc := make(chan *Request, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + reqc <- r + })) + defer ts.Close() + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("Range", "bytes=7-11") + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + select { + case r := <-reqc: + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + t.Error("Transport advertised gzip support in the Accept header") + } + if r.Header.Get("Range") == "" { + t.Error("no Range in request") + } + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + res.Body.Close() +} + +func wantBody(res *http.Response, err error, want string) error { + if err != nil { + return err + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading body: %v", err) + } + if string(slurp) != want { + return fmt.Errorf("body = %q; want %q", slurp, want) + } + if err := res.Body.Close(); err != nil { + return fmt.Errorf("body Close = %v", err) + } + return nil +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +type countCloseReader struct { + n *int + io.Reader +} + +func (cr countCloseReader) Close() error { + (*cr.n)++ + return nil +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} diff --git a/src/net/http/triv.go b/src/net/http/triv.go new file mode 100644 index 000000000..232d65089 --- /dev/null +++ b/src/net/http/triv.go @@ -0,0 +1,141 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package main + +import ( + "bytes" + "expvar" + "flag" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "strconv" + "sync" +) + +// hello world, the web server +var helloRequests = expvar.NewInt("hello-requests") + +func HelloServer(w http.ResponseWriter, req *http.Request) { + helloRequests.Add(1) + io.WriteString(w, "hello, world!\n") +} + +// Simple counter server. POSTing to it will set the value. +type Counter struct { + mu sync.Mutex // protects n + n int +} + +// This makes Counter satisfy the expvar.Var interface, so we can export +// it directly. +func (ctr *Counter) String() string { + ctr.mu.Lock() + defer ctr.mu.Unlock() + return fmt.Sprintf("%d", ctr.n) +} + +func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) { + ctr.mu.Lock() + defer ctr.mu.Unlock() + switch req.Method { + case "GET": + ctr.n++ + case "POST": + buf := new(bytes.Buffer) + io.Copy(buf, req.Body) + body := buf.String() + if n, err := strconv.Atoi(body); err != nil { + fmt.Fprintf(w, "bad POST: %v\nbody: [%v]\n", err, body) + } else { + ctr.n = n + fmt.Fprint(w, "counter reset\n") + } + } + fmt.Fprintf(w, "counter = %d\n", ctr.n) +} + +// simple flag server +var booleanflag = flag.Bool("boolean", true, "another flag for testing") + +func FlagServer(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, "Flags:\n") + flag.VisitAll(func(f *flag.Flag) { + if f.Value.String() != f.DefValue { + fmt.Fprintf(w, "%s = %s [default = %s]\n", f.Name, f.Value.String(), f.DefValue) + } else { + fmt.Fprintf(w, "%s = %s\n", f.Name, f.Value.String()) + } + }) +} + +// simple argument server +func ArgServer(w http.ResponseWriter, req *http.Request) { + for _, s := range os.Args { + fmt.Fprint(w, s, " ") + } +} + +// a channel (just for the fun of it) +type Chan chan int + +func ChanCreate() Chan { + c := make(Chan) + go func(c Chan) { + for x := 0; ; x++ { + c <- x + } + }(c) + return c +} + +func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) { + io.WriteString(w, fmt.Sprintf("channel send #%d\n", <-ch)) +} + +// exec a program, redirecting output +func DateServer(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Content-Type", "text/plain; charset=utf-8") + + date, err := exec.Command("/bin/date").Output() + if err != nil { + http.Error(rw, err.Error(), 500) + return + } + rw.Write(date) +} + +func Logger(w http.ResponseWriter, req *http.Request) { + log.Print(req.URL) + http.Error(w, "oops", 404) +} + +var webroot = flag.String("root", os.Getenv("HOME"), "web root directory") + +func main() { + flag.Parse() + + // The counter is published as a variable directly. + ctr := new(Counter) + expvar.Publish("counter", ctr) + http.Handle("/counter", ctr) + http.Handle("/", http.HandlerFunc(Logger)) + http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot)))) + http.Handle("/chan", ChanCreate()) + http.HandleFunc("/flags", FlagServer) + http.HandleFunc("/args", ArgServer) + http.HandleFunc("/go/hello", HelloServer) + http.HandleFunc("/date", DateServer) + err := http.ListenAndServe(":12345", nil) + if err != nil { + log.Panicln("ListenAndServe:", err) + } +} |