diff options
Diffstat (limited to 'src/pkg/http')
34 files changed, 2290 insertions, 537 deletions
diff --git a/src/pkg/http/cgi/Makefile b/src/pkg/http/cgi/Makefile index 02f6cfc9e..19b1039c2 100644 --- a/src/pkg/http/cgi/Makefile +++ b/src/pkg/http/cgi/Makefile @@ -6,6 +6,7 @@ include ../../../Make.inc TARG=http/cgi GOFILES=\ - cgi.go\ + child.go\ + host.go\ include ../../../Make.pkg diff --git a/src/pkg/http/cgi/child.go b/src/pkg/http/cgi/child.go new file mode 100644 index 000000000..c7d48b9eb --- /dev/null +++ b/src/pkg/http/cgi/child.go @@ -0,0 +1,192 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements CGI from the perspective of a child +// process. + +package cgi + +import ( + "bufio" + "fmt" + "http" + "io" + "io/ioutil" + "os" + "strconv" + "strings" +) + +// Request returns the HTTP request as represented in the current +// environment. This assumes the current program is being run +// by a web server in a CGI environment. +func Request() (*http.Request, os.Error) { + return requestFromEnvironment(envMap(os.Environ())) +} + +func envMap(env []string) map[string]string { + m := make(map[string]string) + for _, kv := range env { + if idx := strings.Index(kv, "="); idx != -1 { + m[kv[:idx]] = kv[idx+1:] + } + } + return m +} + +// These environment variables are manually copied into Request +var skipHeader = map[string]bool{ + "HTTP_HOST": true, + "HTTP_REFERER": true, + "HTTP_USER_AGENT": true, +} + +func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { + r := new(http.Request) + r.Method = env["REQUEST_METHOD"] + if r.Method == "" { + return nil, os.NewError("cgi: no REQUEST_METHOD in environment") + } + r.Close = true + r.Trailer = http.Header{} + r.Header = http.Header{} + + r.Host = env["HTTP_HOST"] + r.Referer = env["HTTP_REFERER"] + r.UserAgent = env["HTTP_USER_AGENT"] + + // CGI doesn't allow chunked requests, so these should all be accurate: + r.Proto = "HTTP/1.0" + r.ProtoMajor = 1 + r.ProtoMinor = 0 + r.TransferEncoding = nil + + if lenstr := env["CONTENT_LENGTH"]; lenstr != "" { + clen, err := strconv.Atoi64(lenstr) + if err != nil { + return nil, os.NewError("cgi: bad CONTENT_LENGTH in environment: " + lenstr) + } + r.ContentLength = clen + r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, clen)) + } + + // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers + for k, v := range env { + if !strings.HasPrefix(k, "HTTP_") || skipHeader[k] { + continue + } + r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v) + } + + // TODO: cookies. parsing them isn't exported, though. + + if r.Host != "" { + // Hostname is provided, so we can reasonably construct a URL, + // even if we have to assume 'http' for the scheme. + r.RawURL = "http://" + r.Host + env["REQUEST_URI"] + url, err := http.ParseURL(r.RawURL) + if err != nil { + return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + r.RawURL) + } + r.URL = url + } + // Fallback logic if we don't have a Host header or the URL + // failed to parse + if r.URL == nil { + r.RawURL = env["REQUEST_URI"] + url, err := http.ParseURL(r.RawURL) + if err != nil { + return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + r.RawURL) + } + r.URL = url + } + return r, nil +} + +// Serve executes the provided Handler on the currently active CGI +// request, if any. If there's no current CGI environment +// an error is returned. The provided handler may be nil to use +// http.DefaultServeMux. +func Serve(handler http.Handler) os.Error { + req, err := Request() + if err != nil { + return err + } + if handler == nil { + handler = http.DefaultServeMux + } + rw := &response{ + req: req, + header: make(http.Header), + bufw: bufio.NewWriter(os.Stdout), + } + handler.ServeHTTP(rw, req) + if err = rw.bufw.Flush(); err != nil { + return err + } + return nil +} + +type response struct { + req *http.Request + header http.Header + bufw *bufio.Writer + headerSent bool +} + +func (r *response) Flush() { + r.bufw.Flush() +} + +func (r *response) RemoteAddr() string { + return os.Getenv("REMOTE_ADDR") +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(p []byte) (n int, err os.Error) { + if !r.headerSent { + r.WriteHeader(http.StatusOK) + } + return r.bufw.Write(p) +} + +func (r *response) WriteHeader(code int) { + if r.headerSent { + // Note: explicitly using Stderr, as Stdout is our HTTP output. + fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL) + return + } + r.headerSent = true + fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code)) + + // Set a default Content-Type + if _, hasType := r.header["Content-Type"]; !hasType { + r.header.Add("Content-Type", "text/html; charset=utf-8") + } + + // TODO: add a method on http.Header to write itself to an io.Writer? + // This is duplicated code. + for k, vv := range r.header { + for _, v := range vv { + v = strings.Replace(v, "\n", "", -1) + v = strings.Replace(v, "\r", "", -1) + v = strings.TrimSpace(v) + fmt.Fprintf(r.bufw, "%s: %s\r\n", k, v) + } + } + r.bufw.Write([]byte("\r\n")) + r.bufw.Flush() +} + +func (r *response) UsingTLS() bool { + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := os.Getenv("HTTPS"); s == "on" || s == "ON" || s == "1" { + return true + } + return false +} diff --git a/src/pkg/http/cgi/child_test.go b/src/pkg/http/cgi/child_test.go new file mode 100644 index 000000000..db0e09cf6 --- /dev/null +++ b/src/pkg/http/cgi/child_test.go @@ -0,0 +1,83 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for CGI (the child process perspective) + +package cgi + +import ( + "testing" +) + +func TestRequest(t *testing.T) { + env := map[string]string{ + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "HTTP_USER_AGENT": "goclient", + "HTTP_FOO_BAR": "baz", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + } + req, err := requestFromEnvironment(env) + if err != nil { + t.Fatalf("requestFromEnvironment: %v", err) + } + if g, e := req.UserAgent, "goclient"; e != g { + t.Errorf("expected UserAgent %q; got %q", e, g) + } + if g, e := req.Method, "GET"; e != g { + t.Errorf("expected Method %q; got %q", e, g) + } + if g, e := req.Header.Get("User-Agent"), ""; e != g { + // Tests that we don't put recognized headers in the map + t.Errorf("expected User-Agent %q; got %q", e, g) + } + if g, e := req.ContentLength, int64(123); e != g { + t.Errorf("expected ContentLength %d; got %d", e, g) + } + if g, e := req.Referer, "elsewhere"; e != g { + t.Errorf("expected Referer %q; got %q", e, g) + } + if req.Header == nil { + t.Fatalf("unexpected nil Header") + } + if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g { + t.Errorf("expected Foo-Bar %q; got %q", e, g) + } + if g, e := req.RawURL, "http://example.com/path?a=b"; e != g { + t.Errorf("expected RawURL %q; got %q", e, g) + } + if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if g, e := req.FormValue("a"), "b"; e != g { + t.Errorf("expected FormValue(a) %q; got %q", e, g) + } + if req.Trailer == nil { + t.Errorf("unexpected nil Trailer") + } +} + +func TestRequestWithoutHost(t *testing.T) { + env := map[string]string{ + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + } + req, err := requestFromEnvironment(env) + if err != nil { + t.Fatalf("requestFromEnvironment: %v", err) + } + if g, e := req.RawURL, "/path?a=b"; e != g { + t.Errorf("expected RawURL %q; got %q", e, g) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } +} diff --git a/src/pkg/http/cgi/cgi.go b/src/pkg/http/cgi/host.go index dba59efa2..a713d7c3c 100644 --- a/src/pkg/http/cgi/cgi.go +++ b/src/pkg/http/cgi/host.go @@ -2,6 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// This file implements the host side of CGI (being the webserver +// parent process). + // Package cgi implements CGI (Common Gateway Interface) as specified // in RFC 3875. // @@ -12,14 +15,15 @@ package cgi import ( - "encoding/line" + "bufio" + "bytes" "exec" "fmt" "http" "io" "log" "os" - "path" + "path/filepath" "regexp" "strconv" "strings" @@ -29,10 +33,12 @@ var trailingPort = regexp.MustCompile(`:([0-9]+)$`) // Handler runs an executable in a subprocess with a CGI environment. type Handler struct { - Path string // path to the CGI executable - Root string // root URI prefix of handler or empty for "/" + Path string // path to the CGI executable + Root string // root URI prefix of handler or empty for "/" + Env []string // extra environment variables to set, if any Logger *log.Logger // optional log for errors or nil to use log.Print + Args []string // optional arguments to pass to child process } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -68,14 +74,29 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { "PATH_INFO=" + pathInfo, "SCRIPT_NAME=" + root, "SCRIPT_FILENAME=" + h.Path, - "REMOTE_ADDR=" + rw.RemoteAddr(), - "REMOTE_HOST=" + rw.RemoteAddr(), + "REMOTE_ADDR=" + req.RemoteAddr, + "REMOTE_HOST=" + req.RemoteAddr, "SERVER_PORT=" + port, } - for k, _ := range req.Header { + if req.TLS != nil { + env = append(env, "HTTPS=on") + } + + if len(req.Cookie) > 0 { + b := new(bytes.Buffer) + for idx, c := range req.Cookie { + if idx > 0 { + b.Write([]byte("; ")) + } + fmt.Fprintf(b, "%s=%s", c.Name, c.Value) + } + env = append(env, "HTTP_COOKIE="+b.String()) + } + + for k, v := range req.Header { k = strings.Map(upperCaseAndUnderscore, k) - env = append(env, "HTTP_"+k+"="+req.Header.Get(k)) + env = append(env, "HTTP_"+k+"="+strings.Join(v, ", ")) } if req.ContentLength > 0 { @@ -89,15 +110,17 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, h.Env...) } - // TODO: use filepath instead of path when available - cwd, pathBase := path.Split(h.Path) + cwd, pathBase := filepath.Split(h.Path) if cwd == "" { cwd = "." } + args := []string{h.Path} + args = append(args, h.Args...) + cmd, err := exec.Run( pathBase, - []string{h.Path}, + args, env, cwd, exec.Pipe, // stdin @@ -119,8 +142,8 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { go io.Copy(cmd.Stdin, req.Body) } - linebody := line.NewReader(cmd.Stdout, 1024) - headers := make(map[string]string) + linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024) + headers := rw.Header() statusCode := http.StatusOK for { line, isPrefix, err := linebody.ReadLine() @@ -162,12 +185,9 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } statusCode = code default: - headers[header] = val + headers.Add(header, val) } } - for h, v := range headers { - rw.SetHeader(h, v) - } rw.WriteHeader(statusCode) _, err = io.Copy(rw, linebody) diff --git a/src/pkg/http/cgi/cgi_test.go b/src/pkg/http/cgi/host_test.go index daf9a2cb3..e8084b113 100644 --- a/src/pkg/http/cgi/cgi_test.go +++ b/src/pkg/http/cgi/host_test.go @@ -37,6 +37,7 @@ func newRequest(httpreq string) *http.Request { if err != nil { panic("cgi: bogus http request in test: " + httpreq) } + req.RemoteAddr = "1.2.3.4" return req } @@ -47,6 +48,7 @@ func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string // Make a map to hold the test map that the CGI returns. m := make(map[string]string) + linesRead := 0 readlines: for { line, err := rw.Body.ReadString('\n') @@ -56,10 +58,12 @@ readlines: case err != nil: t.Fatalf("unexpected error reading from CGI: %v", err) } - line = strings.TrimRight(line, "\r\n") - split := strings.Split(line, "=", 2) + linesRead++ + trimmedLine := strings.TrimRight(line, "\r\n") + split := strings.Split(trimmedLine, "=", 2) if len(split) != 2 { - t.Fatalf("Unexpected %d parts from invalid line: %q", len(split), line) + t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", + len(split), linesRead, line, m) } m[split[0]] = split[1] } @@ -111,10 +115,10 @@ func TestCGIBasicGet(t *testing.T) { } replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) - if expected, got := "text/html", replay.Header.Get("Content-Type"); got != expected { + if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected { t.Errorf("got a Content-Type of %q; expected %q", got, expected) } - if expected, got := "X-Test-Value", replay.Header.Get("X-Test-Header"); got != expected { + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) } } @@ -176,6 +180,28 @@ func TestPathInfoDirRoot(t *testing.T) { runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) } +func TestDupHeaders(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-HTTP_COOKIE": "nom=NOM; yum=YUM", + "env-HTTP_X_FOO": "val1, val2", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ + "Cookie: nom=NOM\n"+ + "Cookie: yum=YUM\n"+ + "X-Foo: val1\n"+ + "X-Foo: val2\n"+ + "Host: example.com\n\n", + expectedMap) +} + func TestPathInfoNoRoot(t *testing.T) { if skipTest(t) { return diff --git a/src/pkg/http/cgi/matryoshka_test.go b/src/pkg/http/cgi/matryoshka_test.go new file mode 100644 index 000000000..3e4a6addf --- /dev/null +++ b/src/pkg/http/cgi/matryoshka_test.go @@ -0,0 +1,74 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests a Go CGI program running under a Go CGI host process. +// Further, the two programs are the same binary, just checking +// their environment to figure out what mode to run in. + +package cgi + +import ( + "fmt" + "http" + "os" + "testing" +) + +// This test is a CGI host (testing host.go) that runs its own binary +// as a child process testing the other half of CGI (child.go). +func TestHostingOurselves(t *testing.T) { + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "test": "Hello CGI-in-CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.go?foo=bar&a=b", + "env-SCRIPT_FILENAME": os.Args[0], + "env-SCRIPT_NAME": "/test.go", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +// Note: not actually a test. +func TestBeChildCGIProcess(t *testing.T) { + if os.Getenv("REQUEST_METHOD") == "" { + // Not in a CGI environment; skipping test. + return + } + Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Test-Header", "X-Test-Value") + fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") + req.ParseForm() + for k, vv := range req.Form { + for _, v := range vv { + fmt.Fprintf(rw, "param-%s=%s\n", k, v) + } + } + for _, kv := range os.Environ() { + fmt.Fprintf(rw, "env-%s\n", kv) + } + })) + os.Exit(0) +} diff --git a/src/pkg/http/cgi/testdata/test.cgi b/src/pkg/http/cgi/testdata/test.cgi index b931b04c5..253589eed 100755 --- a/src/pkg/http/cgi/testdata/test.cgi +++ b/src/pkg/http/cgi/testdata/test.cgi @@ -12,7 +12,7 @@ my $q = CGI->new; my $params = $q->Vars; my $NL = "\r\n"; -$NL = "\n" if 1 || $params->{mode} eq "NL"; +$NL = "\n" if $params->{mode} eq "NL"; my $p = sub { print "$_[0]$NL"; @@ -30,5 +30,7 @@ foreach my $k (sort keys %$params) { } foreach my $k (sort keys %ENV) { - print "env-$k=$ENV{$k}\n"; + my $clean_env = $ENV{$k}; + $clean_env =~ s/[\n\r]//g; + print "env-$k=$clean_env\n"; } diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index c24eea581..daba3a89b 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -11,6 +11,7 @@ import ( "encoding/base64" "fmt" "io" + "io/ioutil" "os" "strconv" "strings" @@ -20,26 +21,28 @@ import ( // that uses DefaultTransport. // Client is not yet very configurable. type Client struct { - Transport Transport // if nil, DefaultTransport is used + Transport RoundTripper // if nil, DefaultTransport is used } // DefaultClient is the default Client and is used by Get, Head, and Post. var DefaultClient = &Client{} -// Transport is an interface representing the ability to execute a +// RoundTripper is an interface representing the ability to execute a // single HTTP transaction, obtaining the Response for a given Request. -type Transport interface { - // Do executes a single HTTP transaction, returning the Response for the - // request req. Do should not attempt to interpret the response. - // In particular, Do must return err == nil if it obtained a response, - // regardless of the response's HTTP status code. A non-nil err should - // be reserved for failure to obtain a response. Similarly, Do should - // not attempt to handle higher-level protocol details such as redirects, +type RoundTripper interface { + // RoundTrip executes a single HTTP transaction, returning + // the Response for the request req. RoundTrip should not + // attempt to interpret the response. In particular, + // RoundTrip must return err == nil if it obtained a response, + // regardless of the response's HTTP status code. A non-nil + // err should be reserved for failure to obtain a response. + // Similarly, RoundTrip should not attempt to handle + // higher-level protocol details such as redirects, // authentication, or cookies. // - // Transports may modify the request. The request Headers field is - // guaranteed to be initalized. - Do(req *Request) (resp *Response, err os.Error) + // RoundTrip may modify the request. The request Headers field is + // guaranteed to be initialized. + RoundTrip(req *Request) (resp *Response, err os.Error) } // Given a string of the form "host", "host:port", or "[ipv6::address]:port", @@ -54,40 +57,6 @@ type readClose struct { io.Closer } -// matchNoProxy returns true if requests to addr should not use a proxy, -// according to the NO_PROXY or no_proxy environment variable. -func matchNoProxy(addr string) bool { - if len(addr) == 0 { - return false - } - no_proxy := os.Getenv("NO_PROXY") - if len(no_proxy) == 0 { - no_proxy = os.Getenv("no_proxy") - } - if no_proxy == "*" { - return true - } - - addr = strings.ToLower(strings.TrimSpace(addr)) - if hasPort(addr) { - addr = addr[:strings.LastIndex(addr, ":")] - } - - for _, p := range strings.Split(no_proxy, ",", -1) { - p = strings.ToLower(strings.TrimSpace(p)) - if len(p) == 0 { - continue - } - if hasPort(p) { - p = p[:strings.LastIndex(p, ":")] - } - if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) { - return true - } - } - return false -} - // Do sends an HTTP request and returns an HTTP response, following // policy (e.g. redirects, cookies, auth) as configured on the client. // @@ -100,11 +69,7 @@ func (c *Client) Do(req *Request) (resp *Response, err os.Error) { // send issues an HTTP request. Caller should close resp.Body when done reading from it. -// -// TODO: support persistent connections (multiple requests on a single connection). -// send() method is nonpublic because, when we refactor the code for persistent -// connections, it may no longer make sense to have a method with this signature. -func send(req *Request, t Transport) (resp *Response, err os.Error) { +func send(req *Request, t RoundTripper) (resp *Response, err os.Error) { if t == nil { t = DefaultTransport if t == nil { @@ -130,7 +95,7 @@ func send(req *Request, t Transport) (resp *Response, err os.Error) { } req.Header.Set("Authorization", "Basic "+string(encoded)) } - return t.Do(req) + return t.RoundTrip(req) } // True if the specified HTTP status code is one for which the Get utility should @@ -237,7 +202,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, req.ProtoMajor = 1 req.ProtoMinor = 1 req.Close = true - req.Body = nopCloser{body} + req.Body = ioutil.NopCloser(body) req.Header = Header{ "Content-Type": {bodyType}, } @@ -272,7 +237,7 @@ func (c *Client) PostForm(url string, data map[string]string) (r *Response, err req.ProtoMinor = 1 req.Close = true body := urlencode(data) - req.Body = nopCloser{body} + req.Body = ioutil.NopCloser(body) req.Header = Header{ "Content-Type": {"application/x-www-form-urlencoded"}, "Content-Length": {strconv.Itoa(body.Len())}, @@ -312,9 +277,3 @@ func (c *Client) Head(url string) (r *Response, err os.Error) { } return send(&req, c.Transport) } - -type nopCloser struct { - io.Reader -} - -func (nopCloser) Close() os.Error { return nil } diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index c89ecbce2..3a6f83425 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -4,20 +4,28 @@ // Tests for client.go -package http +package http_test import ( + "fmt" + . "http" + "http/httptest" "io/ioutil" "os" "strings" "testing" ) +var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Last-Modified", "sometime") + fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") +}) + func TestClient(t *testing.T) { - // TODO: add a proper test suite. Current test merely verifies that - // we can retrieve the Google robots.txt file. + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() - r, _, err := Get("http://www.google.com/robots.txt") + r, _, err := Get(ts.URL) var b []byte if err == nil { b, err = ioutil.ReadAll(r.Body) @@ -31,7 +39,10 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { - r, err := Head("http://www.google.com/robots.txt") + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() + + r, err := Head(ts.URL) if err != nil { t.Fatal(err) } @@ -44,7 +55,7 @@ type recordingTransport struct { req *Request } -func (t *recordingTransport) Do(req *Request) (resp *Response, err os.Error) { +func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err os.Error) { t.req = req return nil, os.NewError("dummy impl") } diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go index ff75c47c9..2bb66e58e 100644 --- a/src/pkg/http/cookie.go +++ b/src/pkg/http/cookie.go @@ -15,65 +15,28 @@ import ( "time" ) -// A note on Version=0 vs. Version=1 cookies +// This implementation is done according to IETF draft-ietf-httpstate-cookie-23, found at // -// The difference between Set-Cookie and Set-Cookie2 is hard to discern from the -// RFCs as it is not stated explicitly. There seem to be three standards -// lingering on the web: Netscape, RFC 2109 (aka Version=0) and RFC 2965 (aka -// Version=1). It seems that Netscape and RFC 2109 are the same thing, hereafter -// Version=0 cookies. -// -// In general, Set-Cookie2 is a superset of Set-Cookie. It has a few new -// attributes like HttpOnly and Secure. To be meticulous, if a server intends -// to use these, it needs to send a Set-Cookie2. However, it is most likely -// most modern browsers will not complain seeing an HttpOnly attribute in a -// Set-Cookie header. -// -// Both RFC 2109 and RFC 2965 use Cookie in the same way - two send cookie -// values from clients to servers - and the allowable attributes seem to be the -// same. -// -// The Cookie2 header is used for a different purpose. If a client suspects that -// the server speaks Version=1 (RFC 2965) then along with the Cookie header -// lines, you can also send: -// -// Cookie2: $Version="1" -// -// in order to suggest to the server that you understand Version=1 cookies. At -// which point the server may continue responding with Set-Cookie2 headers. -// When a client sends the (above) Cookie2 header line, it must be prepated to -// understand incoming Set-Cookie2. -// -// This implementation of cookies supports neither Set-Cookie2 nor Cookie2 -// headers. However, it parses Version=1 Cookies (along with Version=0) as well -// as Set-Cookie headers which utilize the full Set-Cookie2 syntax. - -// TODO(petar): Explicitly forbid parsing of Set-Cookie attributes -// starting with '$', which have been used to hack into broken -// servers using the eventual Request headers containing those -// invalid attributes that may overwrite intended $Version, $Path, -// etc. attributes. -// TODO(petar): Read 'Set-Cookie2' headers and prioritize them over equivalent -// 'Set-Cookie' headers. 'Set-Cookie2' headers are still extremely rare. +// http://tools.ietf.org/html/draft-ietf-httpstate-cookie-23 -// A Cookie represents an RFC 2965 HTTP cookie as sent in -// the Set-Cookie header of an HTTP response or the Cookie header -// of an HTTP request. -// The Set-Cookie2 and Cookie2 headers are unimplemented. +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. type Cookie struct { Name string Value string Path string Domain string - Comment string - Version int Expires time.Time RawExpires string - MaxAge int // Max age in seconds - Secure bool - HttpOnly bool - Raw string - Unparsed []string // Raw text of unparsed attribute-value pairs + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs } // readSetCookies parses all "Set-Cookie" values from @@ -94,16 +57,19 @@ func readSetCookies(h Header) []*Cookie { continue } name, value := parts[0][:j], parts[0][j+1:] - value, err := URLUnescape(value) - if err != nil { + if !isCookieNameValid(name) { + unparsedLines = append(unparsedLines, line) + continue + } + value, success := parseCookieValue(value) + if !success { unparsedLines = append(unparsedLines, line) continue } c := &Cookie{ - Name: name, - Value: value, - MaxAge: -1, // Not specified - Raw: line, + Name: name, + Value: value, + Raw: line, } for i := 1; i < len(parts); i++ { parts[i] = strings.TrimSpace(parts[i]) @@ -114,11 +80,11 @@ func readSetCookies(h Header) []*Cookie { attr, val := parts[i], "" if j := strings.Index(attr, "="); j >= 0 { attr, val = attr[:j], attr[j+1:] - val, err = URLUnescape(val) - if err != nil { - c.Unparsed = append(c.Unparsed, parts[i]) - continue - } + } + val, success = parseCookieValue(val) + if !success { + c.Unparsed = append(c.Unparsed, parts[i]) + continue } switch strings.ToLower(attr) { case "secure": @@ -127,19 +93,20 @@ func readSetCookies(h Header) []*Cookie { case "httponly": c.HttpOnly = true continue - case "comment": - c.Comment = val - continue case "domain": c.Domain = val // TODO: Add domain parsing continue case "max-age": secs, err := strconv.Atoi(val) - if err != nil || secs < 0 { + if err != nil || secs < 0 || secs != 0 && val[0] == '0' { break } - c.MaxAge = secs + if secs <= 0 { + c.MaxAge = -1 + } else { + c.MaxAge = secs + } continue case "expires": c.RawExpires = val @@ -154,13 +121,6 @@ func readSetCookies(h Header) []*Cookie { c.Path = val // TODO: Add path parsing continue - case "version": - c.Version, err = strconv.Atoi(val) - if err != nil { - c.Version = 0 - break - } - continue } c.Unparsed = append(c.Unparsed, parts[i]) } @@ -182,11 +142,7 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { var b bytes.Buffer for _, c := range kk { b.Reset() - // TODO(petar): c.Value (below) should be unquoted if it is recognized as quoted - fmt.Fprintf(&b, "%s=%s", CanonicalHeaderKey(c.Name), c.Value) - if c.Version > 0 { - fmt.Fprintf(&b, "Version=%d; ", c.Version) - } + fmt.Fprintf(&b, "%s=%s", c.Name, c.Value) if len(c.Path) > 0 { fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path)) } @@ -196,8 +152,10 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { if len(c.Expires.Zone) > 0 { fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) } - if c.MaxAge >= 0 { + if c.MaxAge > 0 { fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) + } else if c.MaxAge < 0 { + fmt.Fprintf(&b, "; Max-Age=0") } if c.HttpOnly { fmt.Fprintf(&b, "; HttpOnly") @@ -205,9 +163,6 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { if c.Secure { fmt.Fprintf(&b, "; Secure") } - if len(c.Comment) > 0 { - fmt.Fprintf(&b, "; Comment=%s", URLEscape(c.Comment)) - } lines = append(lines, "Set-Cookie: "+b.String()+"\r\n") } sort.SortStrings(lines) @@ -235,63 +190,29 @@ func readCookies(h Header) []*Cookie { continue } // Per-line attributes - var lineCookies = make(map[string]string) - var version int - var path string - var domain string - var comment string - var httponly bool + parsedPairs := 0 for i := 0; i < len(parts); i++ { parts[i] = strings.TrimSpace(parts[i]) if len(parts[i]) == 0 { continue } attr, val := parts[i], "" - var err os.Error if j := strings.Index(attr, "="); j >= 0 { attr, val = attr[:j], attr[j+1:] - val, err = URLUnescape(val) - if err != nil { - continue - } } - switch strings.ToLower(attr) { - case "$httponly": - httponly = true - case "$version": - version, err = strconv.Atoi(val) - if err != nil { - version = 0 - continue - } - case "$domain": - domain = val - // TODO: Add domain parsing - case "$path": - path = val - // TODO: Add path parsing - case "$comment": - comment = val - default: - lineCookies[attr] = val + if !isCookieNameValid(attr) { + continue + } + val, success := parseCookieValue(val) + if !success { + continue } + cookies = append(cookies, &Cookie{Name: attr, Value: val}) + parsedPairs++ } - if len(lineCookies) == 0 { + if parsedPairs == 0 { unparsedLines = append(unparsedLines, line) } - for n, v := range lineCookies { - cookies = append(cookies, &Cookie{ - Name: n, - Value: v, - Path: path, - Domain: domain, - Comment: comment, - Version: version, - HttpOnly: httponly, - MaxAge: -1, - Raw: line, - }) - } } h["Cookie"] = unparsedLines, len(unparsedLines) > 0 return cookies @@ -303,28 +224,8 @@ func readCookies(h Header) []*Cookie { // line-length, so it seems safer to place cookies on separate lines. func writeCookies(w io.Writer, kk []*Cookie) os.Error { lines := make([]string, 0, len(kk)) - var b bytes.Buffer for _, c := range kk { - b.Reset() - n := c.Name - if c.Version > 0 { - fmt.Fprintf(&b, "$Version=%d; ", c.Version) - } - // TODO(petar): c.Value (below) should be unquoted if it is recognized as quoted - fmt.Fprintf(&b, "%s=%s", CanonicalHeaderKey(n), c.Value) - if len(c.Path) > 0 { - fmt.Fprintf(&b, "; $Path=%s", URLEscape(c.Path)) - } - if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; $Domain=%s", URLEscape(c.Domain)) - } - if c.HttpOnly { - fmt.Fprintf(&b, "; $HttpOnly") - } - if len(c.Comment) > 0 { - fmt.Fprintf(&b, "; $Comment=%s", URLEscape(c.Comment)) - } - lines = append(lines, "Cookie: "+b.String()+"\r\n") + lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value)) } sort.SortStrings(lines) for _, l := range lines { @@ -334,3 +235,38 @@ func writeCookies(w io.Writer, kk []*Cookie) os.Error { } return nil } + +func unquoteCookieValue(v string) string { + if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { + return v[1 : len(v)-1] + } + return v +} + +func isCookieByte(c byte) bool { + switch true { + case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a, + 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e: + return true + } + return false +} + +func parseCookieValue(raw string) (string, bool) { + raw = unquoteCookieValue(raw) + for i := 0; i < len(raw); i++ { + if !isCookieByte(raw[i]) { + return "", false + } + } + return raw, true +} + +func isCookieNameValid(raw string) bool { + for _, c := range raw { + if !isToken(byte(c)) { + return false + } + } + return true +} diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go index 363c841bb..db0997040 100644 --- a/src/pkg/http/cookie_test.go +++ b/src/pkg/http/cookie_test.go @@ -6,6 +6,8 @@ package http import ( "bytes" + "fmt" + "json" "reflect" "testing" ) @@ -16,8 +18,12 @@ var writeSetCookiesTests = []struct { Raw string }{ { - []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1", MaxAge: -1}}, - "Set-Cookie: Cookie-1=v$1\r\n", + []*Cookie{ + &Cookie{Name: "cookie-1", Value: "v$1"}, + &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + }, + "Set-Cookie: cookie-1=v$1\r\n" + + "Set-Cookie: cookie-2=two; Max-Age=3600\r\n", }, } @@ -38,8 +44,8 @@ var writeCookiesTests = []struct { Raw string }{ { - []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1", MaxAge: -1}}, - "Cookie: Cookie-1=v$1\r\n", + []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1"}}, + "Cookie: cookie-1=v$1\r\n", }, } @@ -61,15 +67,23 @@ var readSetCookiesTests = []struct { }{ { Header{"Set-Cookie": {"Cookie-1=v$1"}}, - []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", MaxAge: -1, Raw: "Cookie-1=v$1"}}, + []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}}, }, } +func toJSON(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%#v", v) + } + return string(b) +} + func TestReadSetCookies(t *testing.T) { for i, tt := range readSetCookiesTests { c := readSetCookies(tt.Header) if !reflect.DeepEqual(c, tt.Cookies) { - t.Errorf("#%d readSetCookies: have\n%#v\nwant\n%#v\n", i, c, tt.Cookies) + t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) continue } } @@ -81,7 +95,7 @@ var readCookiesTests = []struct { }{ { Header{"Cookie": {"Cookie-1=v$1"}}, - []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", MaxAge: -1, Raw: "Cookie-1=v$1"}}, + []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1"}}, }, } @@ -89,7 +103,7 @@ func TestReadCookies(t *testing.T) { for i, tt := range readCookiesTests { c := readCookies(tt.Header) if !reflect.DeepEqual(c, tt.Cookies) { - t.Errorf("#%d readCookies: have\n%#v\nwant\n%#v\n", i, c, tt.Cookies) + t.Errorf("#%d readCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) continue } } diff --git a/src/pkg/http/dump.go b/src/pkg/http/dump.go index 73ac97973..306c45bc2 100644 --- a/src/pkg/http/dump.go +++ b/src/pkg/http/dump.go @@ -7,10 +7,10 @@ package http import ( "bytes" "io" + "io/ioutil" "os" ) - // One of the copies, say from b to r2, could be avoided by using a more // elaborate trick where the other copy is made during Request/Response.Write. // This would complicate things too much, given that these functions are for @@ -23,7 +23,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) { if err = b.Close(); err != nil { return nil, nil, err } - return nopCloser{&buf}, nopCloser{bytes.NewBuffer(buf.Bytes())}, nil + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil } // DumpRequest returns the wire representation of req, diff --git a/src/pkg/http/export_test.go b/src/pkg/http/export_test.go new file mode 100644 index 000000000..47c687760 --- /dev/null +++ b/src/pkg/http/export_test.go @@ -0,0 +1,34 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Bridge package to expose http internals to tests in the http_test +// package. + +package http + +func (t *Transport) IdleConnKeysForTesting() (keys []string) { + keys = make([]string, 0) + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + return + } + for key := range t.idleConn { + keys = append(keys, key) + } + return +} + +func (t *Transport) IdleConnCountForTesting(cacheKey string) int { + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + return 0 + } + conns, ok := t.idleConn[cacheKey] + if !ok { + return 0 + } + return len(conns) +} diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go index a4cd7072e..c5efffca9 100644 --- a/src/pkg/http/fs.go +++ b/src/pkg/http/fs.go @@ -72,7 +72,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { return } - f, err := os.Open(name, os.O_RDONLY, 0) + f, err := os.Open(name) if err != nil { // TODO expose actual error? NotFound(w, r) @@ -108,12 +108,12 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { w.WriteHeader(StatusNotModified) return } - w.SetHeader("Last-Modified", time.SecondsToUTC(d.Mtime_ns/1e9).Format(TimeFormat)) + w.Header().Set("Last-Modified", time.SecondsToUTC(d.Mtime_ns/1e9).Format(TimeFormat)) // use contents of index.html for directory, if present if d.IsDirectory() { index := name + filepath.FromSlash(indexPage) - ff, err := os.Open(index, os.O_RDONLY, 0) + ff, err := os.Open(index) if err == nil { defer ff.Close() dd, err := ff.Stat() @@ -134,43 +134,48 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { size := d.Size code := StatusOK - // use extension to find content type. - ext := filepath.Ext(name) - if ctype := mime.TypeByExtension(ext); ctype != "" { - w.SetHeader("Content-Type", ctype) - } else { - // read first chunk to decide between utf-8 text and binary - var buf [1024]byte - n, _ := io.ReadFull(f, buf[:]) - b := buf[:n] - if isText(b) { - w.SetHeader("Content-Type", "text-plain; charset=utf-8") - } else { - w.SetHeader("Content-Type", "application/octet-stream") // generic binary + // If Content-Type isn't set, use the file's extension to find it. + if w.Header().Get("Content-Type") == "" { + ctype := mime.TypeByExtension(filepath.Ext(name)) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [1024]byte + n, _ := io.ReadFull(f, buf[:]) + b := buf[:n] + if isText(b) { + ctype = "text-plain; charset=utf-8" + } else { + // generic binary + ctype = "application/octet-stream" + } + f.Seek(0, os.SEEK_SET) // rewind to output whole file } - f.Seek(0, 0) // rewind to output whole file + w.Header().Set("Content-Type", ctype) } // handle Content-Range header. // TODO(adg): handle multiple ranges ranges, err := parseRange(r.Header.Get("Range"), size) - if err != nil || len(ranges) > 1 { + if err == nil && len(ranges) > 1 { + err = os.ErrorString("multiple ranges not supported") + } + if err != nil { Error(w, err.String(), StatusRequestedRangeNotSatisfiable) return } if len(ranges) == 1 { ra := ranges[0] - if _, err := f.Seek(ra.start, 0); err != nil { + if _, err := f.Seek(ra.start, os.SEEK_SET); err != nil { Error(w, err.String(), StatusRequestedRangeNotSatisfiable) return } size = ra.length code = StatusPartialContent - w.SetHeader("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, d.Size)) + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, d.Size)) } - w.SetHeader("Accept-Ranges", "bytes") - w.SetHeader("Content-Length", strconv.Itoa64(size)) + w.Header().Set("Accept-Ranges", "bytes") + w.Header().Set("Content-Length", strconv.Itoa64(size)) w.WriteHeader(code) diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go index a89c76d0b..692b9863e 100644 --- a/src/pkg/http/fs_test.go +++ b/src/pkg/http/fs_test.go @@ -85,6 +85,30 @@ func TestServeFile(t *testing.T) { } } +func TestServeFileContentType(t *testing.T) { + const ctype = "icecream/chocolate" + override := false + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if override { + w.Header().Set("Content-Type", ctype) + } + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + get := func(want string) { + resp, _, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if h := resp.Header.Get("Content-Type"); h != want { + t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + } + } + get("text-plain; charset=utf-8") + override = true + get(ctype) +} + func getBody(t *testing.T, req Request) (*Response, []byte) { r, err := DefaultClient.Do(&req) if err != nil { diff --git a/src/pkg/http/httptest/recorder.go b/src/pkg/http/httptest/recorder.go index ec7bde8aa..0dd19a617 100644 --- a/src/pkg/http/httptest/recorder.go +++ b/src/pkg/http/httptest/recorder.go @@ -14,20 +14,17 @@ import ( // ResponseRecorder is an implementation of http.ResponseWriter that // records its mutations for later inspection in tests. type ResponseRecorder struct { - Code int // the HTTP response code from WriteHeader - Header http.Header // if non-nil, the headers to populate - Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to - Flushed bool - - FakeRemoteAddr string // the fake RemoteAddr to return, or "" for DefaultRemoteAddr - FakeUsingTLS bool // whether to return true from the UsingTLS method + Code int // the HTTP response code from WriteHeader + HeaderMap http.Header // the HTTP response headers + Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to + Flushed bool } // NewRecorder returns an initialized ResponseRecorder. func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ - Header: http.Header(make(map[string][]string)), - Body: new(bytes.Buffer), + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), } } @@ -35,29 +32,9 @@ func NewRecorder() *ResponseRecorder { // an explicit DefaultRemoteAddr isn't set on ResponseRecorder. const DefaultRemoteAddr = "1.2.3.4" -// RemoteAddr returns the value of rw.FakeRemoteAddr, if set, else -// returns DefaultRemoteAddr. -func (rw *ResponseRecorder) RemoteAddr() string { - if rw.FakeRemoteAddr != "" { - return rw.FakeRemoteAddr - } - return DefaultRemoteAddr -} - -// UsingTLS returns the fake value in rw.FakeUsingTLS -func (rw *ResponseRecorder) UsingTLS() bool { - return rw.FakeUsingTLS -} - -// SetHeader populates rw.Header, if non-nil. -func (rw *ResponseRecorder) SetHeader(k, v string) { - if rw.Header != nil { - if v == "" { - rw.Header.Del(k) - } else { - rw.Header.Set(k, v) - } - } +// Header returns the response headers. +func (rw *ResponseRecorder) Header() http.Header { + return rw.HeaderMap } // Write always succeeds and writes to rw.Body, if not nil. @@ -65,6 +42,9 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, os.Error) { if rw.Body != nil { rw.Body.Write(buf) } + if rw.Code == 0 { + rw.Code = http.StatusOK + } return len(buf), nil } diff --git a/src/pkg/http/httptest/server.go b/src/pkg/http/httptest/server.go index 86c9eb435..8e385d045 100644 --- a/src/pkg/http/httptest/server.go +++ b/src/pkg/http/httptest/server.go @@ -7,9 +7,13 @@ package httptest import ( + "crypto/rand" + "crypto/tls" "fmt" "http" "net" + "os" + "time" ) // A Server is an HTTP server listening on a system-chosen port on the @@ -17,22 +21,69 @@ import ( type Server struct { URL string // base URL of form http://ipaddr:port with no trailing slash Listener net.Listener + TLS *tls.Config // nil if not using using TLS } -// NewServer starts and returns a new Server. -// The caller should call Close when finished, to shut it down. -func NewServer(handler http.Handler) *Server { - ts := new(Server) +// historyListener keeps track of all connections that it's ever +// accepted. +type historyListener struct { + net.Listener + history []net.Conn +} + +func (hs *historyListener) Accept() (c net.Conn, err os.Error) { + c, err = hs.Listener.Accept() + if err == nil { + hs.history = append(hs.history, c) + } + return +} + +func newLocalListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) } } - ts.Listener = l + return l +} + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := new(Server) + l := newLocalListener() + ts.Listener = &historyListener{l, make([]net.Conn, 0)} ts.URL = "http://" + l.Addr().String() server := &http.Server{Handler: handler} - go server.Serve(l) + go server.Serve(ts.Listener) + return ts +} + +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + l := newLocalListener() + ts := new(Server) + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + ts.TLS = &tls.Config{ + Rand: rand.Reader, + Time: time.Seconds, + NextProtos: []string{"http/1.1"}, + Certificates: []tls.Certificate{cert}, + } + tlsListener := tls.NewListener(l, ts.TLS) + + ts.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} + ts.URL = "https://" + l.Addr().String() + server := &http.Server{Handler: handler} + go server.Serve(ts.Listener) return ts } @@ -40,3 +91,46 @@ func NewServer(handler http.Handler) *Server { func (s *Server) Close() { s.Listener.Close() } + +// CloseClientConnections closes any currently open HTTP connections +// to the test Server. +func (s *Server) CloseClientConnections() { + hl, ok := s.Listener.(*historyListener) + if !ok { + return + } + for _, conn := range hl.history { + conn.Close() + } +} + +// localhostCert is a PEM-encoded TLS cert with SAN DNS names +// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end +// of ASN.1 time). +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBwTCCASugAwIBAgIBADALBgkqhkiG9w0BAQUwADAeFw0xMTAzMzEyMDI1MDda +Fw00OTEyMzEyMzU5NTlaMAAwggCdMAsGCSqGSIb3DQEBAQOCAIwAMIIAhwKCAIB6 +oy4iT42G6qk+GGn5VL5JlnJT6ZG5cqaMNFaNGlIxNb6CPUZLKq2sM3gRaimsktIw +nNAcNwQGHpe1tZo+J/Pl04JTt71Y/TTAxy7OX27aZf1Rpt0SjdZ7vTPnFDPNsHGe +KBKvPt55l2+YKjkZmV7eRevsVbpkNvNGB+T5d4Ge/wIBA6NPME0wDgYDVR0PAQH/ +BAQDAgCgMA0GA1UdDgQGBAQBAgMEMA8GA1UdIwQIMAaABAECAwQwGwYDVR0RBBQw +EoIJMTI3LjAuMC4xggVbOjoxXTALBgkqhkiG9w0BAQUDggCBAHC3gbdvc44vs+wD +g2kONiENnx8WKc0UTGg/TOXS3gaRb+CUIQtHWja65l8rAfclEovjHgZ7gx8brO0W +JuC6p3MUAKsgOssIrrRIx2rpnfcmFVMzguCmrMNVmKUAalw18Yp0F72xYAIitVQl +kJrLdIhBajcJRYu/YGltHQRaXuVt +-----END CERTIFICATE----- +`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBkgIBAQKCAIB6oy4iT42G6qk+GGn5VL5JlnJT6ZG5cqaMNFaNGlIxNb6CPUZL +Kq2sM3gRaimsktIwnNAcNwQGHpe1tZo+J/Pl04JTt71Y/TTAxy7OX27aZf1Rpt0S +jdZ7vTPnFDPNsHGeKBKvPt55l2+YKjkZmV7eRevsVbpkNvNGB+T5d4Ge/wIBAwKC +AIBRwh7Bil5Z8cYpZZv7jdQxDvbim7Z7ocRdeDmzZuF2I9RW04QyHHPIIlALnBvI +YeF1veASz1gEFGUjzmbUGqKYSbCoTzXoev+F4bmbRxcX9sOmtslqvhMSHRSzA5NH +aDVI3Hn4wvBVD8gePu8ACWqvPGbCiql11OKCMfjlPn2uuwJAx/24/F5DjXZ6hQQ7 +HxScOxKrpx5WnA9r1wZTltOTZkhRRzuLc21WJeE3M15QUdWi3zZxCKRFoth65HEs +jy9YHQJAnPueRI44tz79b5QqVbeaOMUr7ZCb1Kp0uo6G+ANPLdlfliAupwij2eIz +mHRJOWk0jBtXfRft1McH2H51CpXAyw== +-----END RSA PRIVATE KEY----- +`) diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go index 53efd7c8c..b93c5fe48 100644 --- a/src/pkg/http/persist.go +++ b/src/pkg/http/persist.go @@ -211,7 +211,9 @@ type ClientConn struct { nread, nwritten int pipereq map[*Request]uint - pipe textproto.Pipeline + pipe textproto.Pipeline + writeReq func(*Request, io.Writer) os.Error + readRes func(buf *bufio.Reader, method string) (*Response, os.Error) } // NewClientConn returns a new ClientConn reading and writing c. If r is not @@ -220,7 +222,21 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { if r == nil { r = bufio.NewReader(c) } - return &ClientConn{c: c, r: r, pipereq: make(map[*Request]uint)} + return &ClientConn{ + c: c, + r: r, + pipereq: make(map[*Request]uint), + writeReq: (*Request).Write, + readRes: ReadResponse, + } +} + +// NewProxyClientConn works like NewClientConn but writes Requests +// using Request's WriteProxy method. +func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + cc := NewClientConn(c, r) + cc.writeReq = (*Request).WriteProxy + return cc } // Close detaches the ClientConn and returns the underlying connection as well @@ -281,7 +297,7 @@ func (cc *ClientConn) Write(req *Request) (err os.Error) { } cc.lk.Unlock() - err = req.Write(c) + err = cc.writeReq(req, c) cc.lk.Lock() defer cc.lk.Unlock() if err != nil { @@ -349,7 +365,7 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { } } - resp, err = ReadResponse(r, req.Method) + resp, err = cc.readRes(r, req.Method) cc.lk.Lock() defer cc.lk.Unlock() if err != nil { diff --git a/src/pkg/http/pprof/pprof.go b/src/pkg/http/pprof/pprof.go index f7db9aab9..bc79e2183 100644 --- a/src/pkg/http/pprof/pprof.go +++ b/src/pkg/http/pprof/pprof.go @@ -18,6 +18,10 @@ // // pprof http://localhost:6060/debug/pprof/heap // +// Or to look at a 30-second CPU profile: +// +// pprof http://localhost:6060/debug/pprof/profile +// package pprof import ( @@ -29,10 +33,12 @@ import ( "runtime/pprof" "strconv" "strings" + "time" ) func init() { http.Handle("/debug/pprof/cmdline", http.HandlerFunc(Cmdline)) + http.Handle("/debug/pprof/profile", http.HandlerFunc(Profile)) http.Handle("/debug/pprof/heap", http.HandlerFunc(Heap)) http.Handle("/debug/pprof/symbol", http.HandlerFunc(Symbol)) } @@ -41,22 +47,46 @@ func init() { // command line, with arguments separated by NUL bytes. // The package initialization registers it as /debug/pprof/cmdline. func Cmdline(w http.ResponseWriter, r *http.Request) { - w.SetHeader("content-type", "text/plain; charset=utf-8") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") fmt.Fprintf(w, strings.Join(os.Args, "\x00")) } // Heap responds with the pprof-formatted heap profile. // The package initialization registers it as /debug/pprof/heap. func Heap(w http.ResponseWriter, r *http.Request) { - w.SetHeader("content-type", "text/plain; charset=utf-8") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") pprof.WriteHeapProfile(w) } +// Profile responds with the pprof-formatted cpu profile. +// The package initialization registers it as /debug/pprof/profile. +func Profile(w http.ResponseWriter, r *http.Request) { + sec, _ := strconv.Atoi64(r.FormValue("seconds")) + if sec == 0 { + sec = 30 + } + + // Set Content Type assuming StartCPUProfile will work, + // because if it does it starts writing. + w.Header().Set("Content-Type", "application/octet-stream") + if err := pprof.StartCPUProfile(w); err != nil { + // StartCPUProfile failed, so no writes yet. + // Can change header back to text content + // and send error code. + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + return + } + time.Sleep(sec * 1e9) + pprof.StopCPUProfile() +} + // Symbol looks up the program counters listed in the request, // responding with a table mapping program counters to function names. // The package initialization registers it as /debug/pprof/symbol. func Symbol(w http.ResponseWriter, r *http.Request) { - w.SetHeader("content-type", "text/plain; charset=utf-8") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") // We don't know how many symbols we have, but we // do have symbol information. Pprof only cares whether diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go index 0f2ca458f..7050ef5ed 100644 --- a/src/pkg/http/proxy_test.go +++ b/src/pkg/http/proxy_test.go @@ -12,31 +12,33 @@ import ( // TODO(mattn): // test ProxyAuth -var MatchNoProxyTests = []struct { +var UseProxyTests = []struct { host string match bool }{ - {"localhost", true}, // match completely - {"barbaz.net", true}, // match as .barbaz.net - {"foobar.com:443", true}, // have a port but match - {"foofoobar.com", false}, // not match as a part of foobar.com - {"baz.com", false}, // not match as a part of barbaz.com - {"localhost.net", false}, // not match as suffix of address - {"local.localhost", false}, // not match as prefix as address - {"barbarbaz.net", false}, // not match because NO_PROXY have a '.' - {"www.foobar.com", false}, // not match because NO_PROXY is not .foobar.com + {"localhost", false}, // match completely + {"barbaz.net", false}, // match as .barbaz.net + {"foobar.com:443", false}, // have a port but match + {"foofoobar.com", true}, // not match as a part of foobar.com + {"baz.com", true}, // not match as a part of barbaz.com + {"localhost.net", true}, // not match as suffix of address + {"local.localhost", true}, // not match as prefix as address + {"barbarbaz.net", true}, // not match because NO_PROXY have a '.' + {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com } -func TestMatchNoProxy(t *testing.T) { +func TestUseProxy(t *testing.T) { oldenv := os.Getenv("NO_PROXY") no_proxy := "foobar.com, .barbaz.net , localhost" os.Setenv("NO_PROXY", no_proxy) defer os.Setenv("NO_PROXY", oldenv) - for _, test := range MatchNoProxyTests { - if matchNoProxy(test.host) != test.match { + tr := &Transport{} + + for _, test := range UseProxyTests { + if tr.useProxy(test.host) != test.match { if test.match { - t.Errorf("matchNoProxy(%v) = %v, want %v", test.host, !test.match, test.match) + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } else { t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy) } diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index d8456bab3..d82894fab 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -11,6 +11,7 @@ package http import ( "bufio" + "crypto/tls" "container/vector" "fmt" "io" @@ -137,6 +138,22 @@ type Request struct { // response has multiple trailer lines with the same key, they will be // concatenated, delimited by commas. Trailer Header + + // RemoteAddr allows HTTP servers and other software to record + // the network address that sent the request, usually for + // logging. This field is not filled in by ReadRequest and + // has no defined format. The HTTP server in this package + // sets RemoteAddr to an "IP:port" address before invoking a + // handler. + RemoteAddr string + + // TLS allows HTTP servers and other software to record + // information about the TLS connection on which the request + // was received. This field is not filled in by ReadRequest. + // The HTTP server in this package sets the field for + // TLS-enabled connections before invoking a handler; + // otherwise it leaves the field nil. + TLS *tls.ConnectionState } // ProtoAtLeast returns whether the HTTP protocol used diff --git a/src/pkg/http/request_test.go b/src/pkg/http/request_test.go index ae1c4e982..19083adf6 100644 --- a/src/pkg/http/request_test.go +++ b/src/pkg/http/request_test.go @@ -2,10 +2,15 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http +package http_test import ( "bytes" + "fmt" + . "http" + "http/httptest" + "io" + "os" "reflect" "regexp" "strings" @@ -141,17 +146,33 @@ func TestMultipartReader(t *testing.T) { } func TestRedirect(t *testing.T) { - const ( - start = "http://google.com/" - endRe = "^http://www\\.google\\.[a-z.]+/$" - ) - var end = regexp.MustCompile(endRe) - r, url, err := Get(start) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.URL.Path { + case "/": + w.Header().Set("Location", "/foo/") + w.WriteHeader(StatusSeeOther) + case "/foo/": + fmt.Fprintf(w, "foo") + default: + w.WriteHeader(StatusBadRequest) + } + })) + defer ts.Close() + + var end = regexp.MustCompile("/foo/$") + r, url, err := Get(ts.URL) if err != nil { t.Fatal(err) } r.Body.Close() if r.StatusCode != 200 || !end.MatchString(url) { - t.Fatalf("Get(%s) got status %d at %q, want 200 matching %q", start, r.StatusCode, url, endRe) + t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url) } } + +// TODO: stop copy/pasting this around. move to io/ioutil? +type nopCloser struct { + io.Reader +} + +func (nopCloser) Close() os.Error { return nil } diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index 03a766efd..726baa266 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -6,6 +6,7 @@ package http import ( "bytes" + "io/ioutil" "testing" ) @@ -158,7 +159,7 @@ func TestRequestWrite(t *testing.T) { for i := range reqWriteTests { tt := &reqWriteTests[i] if tt.Body != nil { - tt.Req.Body = nopCloser{bytes.NewBuffer(tt.Body)} + tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(tt.Body)) } var braw bytes.Buffer err := tt.Req.Write(&braw) @@ -173,7 +174,7 @@ func TestRequestWrite(t *testing.T) { } if tt.Body != nil { - tt.Req.Body = nopCloser{bytes.NewBuffer(tt.Body)} + tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(tt.Body)) } var praw bytes.Buffer err = tt.Req.WriteProxy(&praw) diff --git a/src/pkg/http/response.go b/src/pkg/http/response.go index 3d77c5555..1f725ecdd 100644 --- a/src/pkg/http/response.go +++ b/src/pkg/http/response.go @@ -217,13 +217,19 @@ func (resp *Response) Write(w io.Writer) os.Error { func writeSortedHeader(w io.Writer, h Header, exclude map[string]bool) os.Error { keys := make([]string, 0, len(h)) for k := range h { - if !exclude[k] { + if exclude == nil || !exclude[k] { keys = append(keys, k) } } sort.SortStrings(keys) for _, k := range keys { for _, v := range h[k] { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.TrimSpace(v) + if v == "" { + continue + } if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { return err } diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go index bf63ccb9e..314f05b36 100644 --- a/src/pkg/http/response_test.go +++ b/src/pkg/http/response_test.go @@ -164,6 +164,28 @@ var respTests = []respTest{ "Body here\n", }, + // Chunked response in response to a HEAD request (the "chunked" should + // be ignored, as HEAD responses never have bodies) + { + "HTTP/1.0 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + RequestMethod: "HEAD", + Header: Header{}, + Close: true, + ContentLength: 0, + }, + + "", + }, + // Status line without a Reason-Phrase, but trailing space. // (permitted by RFC 2616) { @@ -229,8 +251,8 @@ func TestReadResponse(t *testing.T) { } func diff(t *testing.T, prefix string, have, want interface{}) { - hv := reflect.NewValue(have).(*reflect.PtrValue).Elem().(*reflect.StructValue) - wv := reflect.NewValue(want).(*reflect.PtrValue).Elem().(*reflect.StructValue) + hv := reflect.NewValue(have).Elem() + wv := reflect.NewValue(want).Elem() if hv.Type() != wv.Type() { t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type()) } @@ -238,7 +260,7 @@ func diff(t *testing.T, prefix string, have, want interface{}) { hf := hv.Field(i).Interface() wf := wv.Field(i).Interface() if !reflect.DeepEqual(hf, wf) { - t.Errorf("%s: %s = %v want %v", prefix, hv.Type().(*reflect.StructType).Field(i).Name, hf, wf) + t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) } } } diff --git a/src/pkg/http/responsewrite_test.go b/src/pkg/http/responsewrite_test.go index 228ed5f7d..de0635da5 100644 --- a/src/pkg/http/responsewrite_test.go +++ b/src/pkg/http/responsewrite_test.go @@ -6,6 +6,7 @@ package http import ( "bytes" + "io/ioutil" "testing" ) @@ -23,7 +24,7 @@ var respWriteTests = []respWriteTest{ ProtoMinor: 0, RequestMethod: "GET", Header: Header{}, - Body: nopCloser{bytes.NewBufferString("abcdef")}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), ContentLength: 6, }, @@ -39,7 +40,7 @@ var respWriteTests = []respWriteTest{ ProtoMinor: 0, RequestMethod: "GET", Header: Header{}, - Body: nopCloser{bytes.NewBufferString("abcdef")}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), ContentLength: -1, }, "HTTP/1.0 200 OK\r\n" + @@ -54,7 +55,7 @@ var respWriteTests = []respWriteTest{ ProtoMinor: 1, RequestMethod: "GET", Header: Header{}, - Body: nopCloser{bytes.NewBufferString("abcdef")}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), ContentLength: 6, TransferEncoding: []string{"chunked"}, Close: true, @@ -65,6 +66,29 @@ var respWriteTests = []respWriteTest{ "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", }, + + // Header value with a newline character (Issue 914). + // Also tests removal of leading and trailing whitespace. + { + Response{ + StatusCode: 204, + ProtoMajor: 1, + ProtoMinor: 1, + RequestMethod: "GET", + Header: Header{ + "Foo": []string{" Bar\nBaz "}, + }, + Body: nil, + ContentLength: 0, + TransferEncoding: []string{"chunked"}, + Close: true, + }, + + "HTTP/1.1 204 No Content\r\n" + + "Connection: close\r\n" + + "Foo: Bar Baz\r\n" + + "\r\n", + }, } func TestResponseWrite(t *testing.T) { @@ -78,7 +102,7 @@ func TestResponseWrite(t *testing.T) { } sraw := braw.String() if sraw != tt.Raw { - t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, sraw) + t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw) continue } } diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index 86d64bdbb..0142dead9 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -15,6 +15,7 @@ import ( "io/ioutil" "os" "net" + "reflect" "strings" "testing" "time" @@ -144,7 +145,7 @@ func TestConsumingBodyOnNextConn(t *testing.T) { type stringHandler string func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) { - w.SetHeader("Result", string(s)) + w.Header().Set("Result", string(s)) } var handlers = []struct { @@ -174,7 +175,7 @@ func TestHostHandlers(t *testing.T) { ts := httptest.NewServer(nil) defer ts.Close() - conn, err := net.Dial("tcp", "", ts.Listener.Addr().String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } @@ -216,7 +217,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { mux.ServeHTTP(resp, req) - if loc, expected := resp.Header.Get("Location"), "/foo.txt"; loc != expected { + if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected { t.Errorf("Expected Location header set to %q; got %q", expected, loc) return } @@ -229,7 +230,8 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0}) + // TODO(bradfitz): convert this to use httptest.Server + l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen error: %v", err) } @@ -248,7 +250,9 @@ func TestServerTimeouts(t *testing.T) { url := fmt.Sprintf("http://localhost:%d/", addr.Port) // Hit the HTTP server successfully. - r, _, err := Get(url) + tr := &Transport{DisableKeepAlives: true} // they interfere with this test + c := &Client{Transport: tr} + r, _, err := c.Get(url) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -261,7 +265,7 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Nanoseconds() - conn, err := net.Dial("tcp", "", fmt.Sprintf("localhost:%d", addr.Port)) + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port)) if err != nil { t.Fatalf("Dial: %v", err) } @@ -294,8 +298,8 @@ func TestServerTimeouts(t *testing.T) { // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { - rw.SetHeader("Content-Length", "3") - rw.SetHeader("Transfer-Encoding", req.FormValue("te")) + rw.Header().Set("Content-Length", "3") + rw.Header().Set("Transfer-Encoding", req.FormValue("te")) switch { case req.FormValue("overwrite") == "1": _, err := rw.Write([]byte("foo TOO LONG")) @@ -303,7 +307,7 @@ func TestIdentityResponse(t *testing.T) { t.Errorf("expected ErrContentLength; got %v", err) } case req.FormValue("underwrite") == "1": - rw.SetHeader("Content-Length", "500") + rw.Header().Set("Content-Length", "500") rw.Write([]byte("too short")) default: rw.Write([]byte("foo")) @@ -333,6 +337,7 @@ func TestIdentityResponse(t *testing.T) { t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)", url, expected, tl, res.TransferEncoding) } + res.Body.Close() } // Verify that ErrContentLength is returned @@ -341,10 +346,9 @@ func TestIdentityResponse(t *testing.T) { if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } - // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. - conn, err := net.Dial("tcp", "", ts.Listener.Addr().String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) } @@ -365,3 +369,250 @@ func TestIdentityResponse(t *testing.T) { expectedSuffix, string(got)) } } + +// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. +func TestServeHTTP10Close(t *testing.T) { + s := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) + defer s.Close() + + conn, err := net.Dial("tcp", s.Listener.Addr().String()) + if err != nil { + t.Fatal("dial error:", err) + } + defer conn.Close() + + _, err = fmt.Fprint(conn, "GET / HTTP/1.0\r\n\r\n") + if err != nil { + t.Fatal("print error:", err) + } + + r := bufio.NewReader(conn) + _, err = ReadResponse(r, "GET") + if err != nil { + t.Fatal("ReadResponse error:", err) + } + + success := make(chan bool) + go func() { + select { + case <-time.After(5e9): + t.Fatal("body not closed after 5s") + case <-success: + } + }() + + _, err = ioutil.ReadAll(r) + if err != nil { + t.Fatal("read error:", err) + } + + success <- true +} + +func TestSetsRemoteAddr(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%s", r.RemoteAddr) + })) + defer ts.Close() + + res, _, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + ip := string(body) + if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") { + t.Fatalf("Expected local addr; got %q", ip) + } +} + +func TestChunkedResponseHeaders(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + fmt.Fprintf(w, "I am a chunked response.") + })) + defer ts.Close() + + res, _, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if g, e := res.ContentLength, int64(-1); g != e { + t.Errorf("expected ContentLength of %d; got %d", e, g) + } + if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) { + t.Errorf("expected TransferEncoding of %v; got %v", e, g) + } + if _, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length") + } +} + +// Test304Responses verifies that 304s don't declare that they're +// chunking in their response headers and aren't allowed to produce +// output. +func Test304Responses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNotModified) + _, err := w.Write([]byte("illegal body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +// TestHeadResponses verifies that responses to HEAD requests don't +// declare that they're chunking in their response headers and aren't +// allowed to produce output. +func TestHeadResponses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("Ignored body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, err := Head(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +func TestTLSServer(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "tls=%v", r.TLS != nil) + })) + defer ts.Close() + if !strings.HasPrefix(ts.URL, "https://") { + t.Fatalf("expected test TLS server to start with https://, got %q", ts.URL) + } + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if res == nil { + t.Fatalf("got nil Response") + } + if res.Body == nil { + t.Fatalf("got nil Response.Body") + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if e, g := "tls=true", string(body); e != g { + t.Errorf("expected body %q; got %q", e, g) + } +} + +type serverExpectTest struct { + contentLength int // of request body + expectation string // e.g. "100-continue" + readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) + expectedResponse string // expected substring in first line of http response +} + +var serverExpectTests = []serverExpectTest{ + // Normal 100-continues, case-insensitive. + {100, "100-continue", true, "100 Continue"}, + {100, "100-cOntInUE", true, "100 Continue"}, + + // No 100-continue. + {100, "", true, "200 OK"}, + + // 100-continue but requesting client to deny us, + // so it never eads the body. + {100, "100-continue", false, "401 Unauthorized"}, + // Likewise without 100-continue: + {100, "", false, "401 Unauthorized"}, + + // Non-standard expectations are failures + {0, "a-pony", false, "417 Expectation Failed"}, + + // Expect-100 requested but no body + {0, "100-continue", true, "400 Bad Request"}, +} + +// Tests that the server responds to the "Expect" request header +// correctly. +func TestServerExpect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Note using r.FormValue("readbody") because for POST + // requests that would read from r.Body, which we only + // conditionally want to do. + if strings.Contains(r.URL.RawPath, "readbody=true") { + ioutil.ReadAll(r.Body) + w.Write([]byte("Hi")) + } else { + w.WriteHeader(StatusUnauthorized) + } + })) + defer ts.Close() + + runTest := func(test serverExpectTest) { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + sendf := func(format string, args ...interface{}) { + _, err := fmt.Fprintf(conn, format, args...) + if err != nil { + t.Fatalf("Error writing %q: %v", format, err) + } + } + go func() { + sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + "Connection: close\r\n"+ + "Content-Length: %d\r\n"+ + "Expect: %s\r\nHost: foo\r\n\r\n", + test.readBody, test.contentLength, test.expectation) + if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + body := strings.Repeat("A", test.contentLength) + sendf(body) + } + }() + bufr := bufio.NewReader(conn) + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatalf("ReadString: %v", err) + } + if !strings.Contains(line, test.expectedResponse) { + t.Errorf("for test %#v got first line=%q", test, line) + } + } + + for _, test := range serverExpectTests { + runTest(test) + } +} diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index 5d623e696..3291de101 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -6,7 +6,6 @@ // TODO(rsc): // logging -// cgi support // post support package http @@ -49,23 +48,10 @@ type Handler interface { // A ResponseWriter interface is used by an HTTP handler to // construct an HTTP response. type ResponseWriter interface { - // RemoteAddr returns the address of the client that sent the current request - RemoteAddr() string - - // UsingTLS returns true if the client is connected using TLS - UsingTLS() bool - - // SetHeader sets a header line in the eventual response. - // For example, SetHeader("Content-Type", "text/html; charset=utf-8") - // will result in the header line - // - // Content-Type: text/html; charset=utf-8 - // - // being sent. UTF-8 encoded HTML is the default setting for - // Content-Type in this library, so users need not make that - // particular call. Calls to SetHeader after WriteHeader (or Write) - // are ignored. An empty value removes the header if previously set. - SetHeader(string, string) + // Header returns the header map that will be sent by WriteHeader. + // Changing the header after a call to WriteHeader (or Write) has + // no effect. + Header() Header // Write writes the data to the connection as part of an HTTP reply. // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK) @@ -78,42 +64,52 @@ type ResponseWriter interface { // Thus explicit calls to WriteHeader are mainly used to // send error codes. WriteHeader(int) +} +// The Flusher interface is implemented by ResponseWriters that allow +// an HTTP handler to flush buffered data to the client. +// +// Note that even for ResponseWriters that support Flush, +// if the client is connected through an HTTP proxy, +// the buffered data may not reach the client until the response +// completes. +type Flusher interface { // Flush sends any buffered data to the client. Flush() } -// A Hijacker is an HTTP request which be taken over by an HTTP handler. +// The Hijacker interface is implemented by ResponseWriters that allow +// an HTTP handler to take over the connection. type Hijacker interface { // Hijack lets the caller take over the connection. // After a call to Hijack(), the HTTP server library // will not do anything else with the connection. // It becomes the caller's responsibility to manage // and close the connection. - Hijack() (io.ReadWriteCloser, *bufio.ReadWriter, os.Error) + Hijack() (net.Conn, *bufio.ReadWriter, os.Error) } // A conn represents the server side of an HTTP connection. type conn struct { - remoteAddr string // network address of remote side - handler Handler // request handler - rwc io.ReadWriteCloser // i/o connection - buf *bufio.ReadWriter // buffered rwc - hijacked bool // connection has been hijacked by handler - usingTLS bool // a flag indicating connection over TLS + remoteAddr string // network address of remote side + handler Handler // request handler + rwc net.Conn // i/o connection + buf *bufio.ReadWriter // buffered rwc + hijacked bool // connection has been hijacked by handler + tlsState *tls.ConnectionState // or nil when not using TLS } // A response represents the server side of an HTTP response. type response struct { conn *conn - req *Request // request for this response - chunking bool // using chunked transfer encoding for reply body - wroteHeader bool // reply header has been written - wroteContinue bool // 100 Continue response was written - header map[string]string // reply header parameters - written int64 // number of bytes written in body - contentLength int64 // explicitly-declared Content-Length; or -1 - status int // status code passed to WriteHeader + req *Request // request for this response + chunking bool // using chunked transfer encoding for reply body + wroteHeader bool // reply header has been written + wroteContinue bool // 100 Continue response was written + header Header // reply header parameters + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader // close connection after this reply. set on request and // updated after response from handler if there's a @@ -128,10 +124,15 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { c.remoteAddr = rwc.RemoteAddr().String() c.handler = handler c.rwc = rwc - _, c.usingTLS = rwc.(*tls.Conn) br := bufio.NewReader(rwc) bw := bufio.NewWriter(rwc) c.buf = bufio.NewReadWriter(br, bw) + + if tlsConn, ok := rwc.(*tls.Conn); ok { + c.tlsState = new(tls.ConnectionState) + *c.tlsState = tlsConn.ConnectionState() + } + return c, nil } @@ -171,35 +172,21 @@ func (c *conn) readRequest() (w *response, err os.Error) { return nil, err } + req.RemoteAddr = c.remoteAddr + req.TLS = c.tlsState + w = new(response) w.conn = c w.req = req - w.header = make(map[string]string) + w.header = make(Header) w.contentLength = -1 - - // Expect 100 Continue support - if req.expectsContinue() && req.ProtoAtLeast(1, 1) { - // Wrap the Body reader with one that replies on the connection - req.Body = &expectContinueReader{readCloser: req.Body, resp: w} - } return w, nil } -// UsingTLS implements the ResponseWriter.UsingTLS -func (w *response) UsingTLS() bool { - return w.conn.usingTLS -} - -// RemoteAddr implements the ResponseWriter.RemoteAddr method -func (w *response) RemoteAddr() string { return w.conn.remoteAddr } - -// SetHeader implements the ResponseWriter.SetHeader method -// An empty value removes the header from the map. -func (w *response) SetHeader(hdr, val string) { - w.header[CanonicalHeaderKey(hdr)] = val, val != "" +func (w *response) Header() Header { + return w.header } -// WriteHeader implements the ResponseWriter.WriteHeader method func (w *response) WriteHeader(code int) { if w.conn.hijacked { log.Print("http: response.WriteHeader on hijacked connection") @@ -214,55 +201,55 @@ func (w *response) WriteHeader(code int) { if code == StatusNotModified { // Must not have body. for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - if w.header[header] != "" { + if w.header.Get(header) != "" { // TODO: return an error if WriteHeader gets a return parameter // or set a flag on w to make future Writes() write an error page? // for now just log and drop the header. log.Printf("http: StatusNotModified response with header %q defined", header) - w.header[header] = "", false + w.header.Del(header) } } } else { // Default output is HTML encoded in UTF-8. - if w.header["Content-Type"] == "" { - w.SetHeader("Content-Type", "text/html; charset=utf-8") + if w.header.Get("Content-Type") == "" { + w.header.Set("Content-Type", "text/html; charset=utf-8") } } - if w.header["Date"] == "" { - w.SetHeader("Date", time.UTC().Format(TimeFormat)) + if w.header.Get("Date") == "" { + w.Header().Set("Date", time.UTC().Format(TimeFormat)) } // Check for a explicit (and valid) Content-Length header. var hasCL bool var contentLength int64 - if clenStr, ok := w.header["Content-Length"]; ok { + if clenStr := w.header.Get("Content-Length"); clenStr != "" { var err os.Error contentLength, err = strconv.Atoi64(clenStr) if err == nil { hasCL = true } else { log.Printf("http: invalid Content-Length of %q sent", clenStr) - w.SetHeader("Content-Length", "") + w.header.Del("Content-Length") } } - te, hasTE := w.header["Transfer-Encoding"] + te := w.header.Get("Transfer-Encoding") + hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", te, contentLength) - w.SetHeader("Content-Length", "") + w.header.Del("Content-Length") hasCL = false } - if w.req.Method == "HEAD" { + if w.req.Method == "HEAD" || code == StatusNotModified { // do nothing } else if hasCL { - w.chunking = false w.contentLength = contentLength - w.SetHeader("Transfer-Encoding", "") + w.header.Del("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -270,20 +257,19 @@ func (w *response) WriteHeader(code int) { // might have set. Deal with that as need arises once we have a valid // use case. w.chunking = true - w.SetHeader("Transfer-Encoding", "chunked") + w.header.Set("Transfer-Encoding", "chunked") } else { // HTTP version < 1.1: cannot do chunked transfer // encoding and we don't know the Content-Length so // signal EOF by closing connection. w.closeAfterReply = true - w.chunking = false // redundant - w.SetHeader("Transfer-Encoding", "") // in case already set + w.header.Del("Transfer-Encoding") // in case already set } if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { _, connectionHeaderSet := w.header["Connection"] if !connectionHeaderSet { - w.SetHeader("Connection", "keep-alive") + w.header.Set("Connection", "keep-alive") } } else if !w.req.ProtoAtLeast(1, 1) { // Client did not ask to keep connection alive. @@ -292,7 +278,7 @@ func (w *response) WriteHeader(code int) { // Cannot use Content-Length with non-identity Transfer-Encoding. if w.chunking { - w.SetHeader("Content-Length", "") + w.header.Del("Content-Length") } if !w.req.ProtoAtLeast(1, 0) { return @@ -307,13 +293,10 @@ func (w *response) WriteHeader(code int) { text = "status code " + codestring } io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - for k, v := range w.header { - io.WriteString(w.conn.buf, k+": "+v+"\r\n") - } + writeSortedHeader(w.conn.buf, w.header, nil) io.WriteString(w.conn.buf, "\r\n") } -// Write implements the ResponseWriter.Write method func (w *response) Write(data []byte) (n int, err os.Error) { if w.conn.hijacked { log.Print("http: response.Write on hijacked connection") @@ -388,7 +371,7 @@ func errorKludge(w *response) { msg += " would ignore this error page if this text weren't here.\n" // Is it text? ("Content-Type" is always in the map) - baseType := strings.Split(w.header["Content-Type"], ";", 2)[0] + baseType := strings.Split(w.header.Get("Content-Type"), ";", 2)[0] switch baseType { case "text/html": io.WriteString(w, "<!-- ") @@ -408,8 +391,8 @@ func (w *response) finishRequest() { // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length // back, we can make this a keep-alive response ... if w.req.wantsHttp10KeepAlive() { - _, sentLength := w.header["Content-Length"] - if sentLength && w.header["Connection"] == "keep-alive" { + sentLength := w.header.Get("Content-Length") != "" + if sentLength && w.header.Get("Connection") == "keep-alive" { w.closeAfterReply = false } } @@ -431,7 +414,6 @@ func (w *response) finishRequest() { } } -// Flush implements the ResponseWriter.Flush method. func (w *response) Flush() { if !w.wroteHeader { w.WriteHeader(StatusOK) @@ -458,6 +440,38 @@ func (c *conn) serve() { if err != nil { break } + + // Expect 100 Continue support + req := w.req + if req.expectsContinue() { + if req.ProtoAtLeast(1, 1) { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + } + if req.ContentLength == 0 { + w.Header().Set("Connection", "close") + w.WriteHeader(StatusBadRequest) + break + } + req.Header.Del("Expect") + } else if req.Header.Get("Expect") != "" { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + break + } + // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. @@ -475,8 +489,9 @@ func (c *conn) serve() { c.close() } -// Hijack impements the ResponseWriter.Hijack method. -func (w *response) Hijack() (rwc io.ReadWriteCloser, buf *bufio.ReadWriter, err os.Error) { +// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter +// and a Hijacker. +func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err os.Error) { if w.conn.hijacked { return nil, nil, ErrHijacked } @@ -503,7 +518,7 @@ func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { // Error replies to the request with the specified error message and HTTP code. func Error(w ResponseWriter, error string, code int) { - w.SetHeader("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(code) fmt.Fprintln(w, error) } @@ -556,7 +571,7 @@ func Redirect(w ResponseWriter, r *Request, url string, code int) { } } - w.SetHeader("Location", url) + w.Header().Set("Location", url) w.WriteHeader(code) // RFC2616 recommends that a short note "SHOULD" be included in the @@ -679,7 +694,7 @@ func (mux *ServeMux) match(path string) Handler { func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { // Clean path to canonical form and redirect. if p := cleanPath(r.URL.Path); p != r.URL.Path { - w.SetHeader("Location", p) + w.Header().Set("Location", p) w.WriteHeader(StatusMovedPermanently) return } @@ -832,7 +847,7 @@ func ListenAndServe(addr string, handler Handler) os.Error { // ) // // func handler(w http.ResponseWriter, req *http.Request) { -// w.SetHeader("Content-Type", "text/plain") +// w.Header().Set("Content-Type", "text/plain") // w.Write([]byte("This is an example server.\n")) // } // diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go index 996e28973..41614f144 100644 --- a/src/pkg/http/transfer.go +++ b/src/pkg/http/transfer.go @@ -215,7 +215,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { } // Transfer encoding, content length - t.TransferEncoding, err = fixTransferEncoding(t.Header) + t.TransferEncoding, err = fixTransferEncoding(t.RequestMethod, t.Header) if err != nil { return err } @@ -289,13 +289,20 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } // Sanitize transfer encoding -func fixTransferEncoding(header Header) ([]string, os.Error) { +func fixTransferEncoding(requestMethod string, header Header) ([]string, os.Error) { raw, present := header["Transfer-Encoding"] if !present { return nil, nil } header["Transfer-Encoding"] = nil, false + + // Head responses have no bodies, so the transfer encoding + // should be ignored. + if requestMethod == "HEAD" { + return nil, nil + } + encodings := strings.Split(raw[0], ",", -1) te := make([]string, 0, len(encodings)) // TODO: Even though we only support "identity" and "chunked" diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index 78d316a55..7fa37af3b 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -6,9 +6,12 @@ package http import ( "bufio" + "compress/gzip" "crypto/tls" "encoding/base64" "fmt" + "io" + "log" "net" "os" "strings" @@ -20,46 +23,109 @@ import ( // each call to Do and uses HTTP proxies as directed by the // $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) // environment variables. -var DefaultTransport Transport = &transport{} - -// transport implements Tranport for the default case, using TCP -// connections to either the host or a proxy, serving http or https -// schemes. In the future this may become public and support options -// on keep-alive connection duration, pipelining controls, etc. For -// now this is simply a port of the old Go code client code to the -// Transport interface. -type transport struct { - // TODO: keep-alives, pipelining, etc using a map from - // scheme/host to a connection. Something like: - l sync.Mutex - hostConn map[string]*ClientConn -} - -func (ct *transport) Do(req *Request) (resp *Response, err os.Error) { +var DefaultTransport RoundTripper = &Transport{} + +// DefaultMaxIdleConnsPerHost is the default value of Transport's +// MaxIdleConnsPerHost. +const DefaultMaxIdleConnsPerHost = 2 + +// Transport is an implementation of RoundTripper that supports http, +// https, and http proxies (for either http or https with CONNECT). +// Transport can also cache connections for future re-use. +type Transport struct { + lk sync.Mutex + idleConn map[string][]*persistConn + + // TODO: tunable on global max cached connections + // TODO: tunable on timeout on cached connections + // TODO: optional pipelining + + IgnoreEnvironment bool // don't look at environment variables for proxy configuration + DisableKeepAlives bool + DisableCompression bool + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) to keep to keep per-host. If zero, + // DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int +} + +// RoundTrip implements the RoundTripper interface. +func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { + if req.URL == nil { + if req.URL, err = ParseURL(req.RawURL); err != nil { + return + } + } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } - addr := req.URL.Host - if !hasPort(addr) { - addr += ":" + req.URL.Scheme + cm, err := t.connectMethodForRequest(req) + if err != nil { + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(cm) + if err != nil { + return nil, err } - var proxyURL *URL - proxyAuth := "" - proxy := "" - if !matchNoProxy(addr) { - proxy = os.Getenv("HTTP_PROXY") - if proxy == "" { - proxy = os.Getenv("http_proxy") + return pconn.roundTrip(req) +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + return + } + for _, conns := range t.idleConn { + for _, pconn := range conns { + pconn.close() } } + t.idleConn = nil +} - var write = (*Request).Write +// +// Private implementation past this point. +// - if proxy != "" { - write = (*Request).WriteProxy - proxyURL, err = ParseRequestURL(proxy) +func (t *Transport) getenvEitherCase(k string) string { + if t.IgnoreEnvironment { + return "" + } + if v := t.getenv(strings.ToUpper(k)); v != "" { + return v + } + return t.getenv(strings.ToLower(k)) +} + +func (t *Transport) getenv(k string) string { + if t.IgnoreEnvironment { + return "" + } + return os.Getenv(k) +} + +func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { + cm := &connectMethod{ + targetScheme: req.URL.Scheme, + targetAddr: canonicalAddr(req.URL), + } + + proxy := t.getenvEitherCase("HTTP_PROXY") + if proxy != "" && t.useProxy(cm.targetAddr) { + proxyURL, err := ParseRequestURL(proxy) if err != nil { return nil, os.ErrorString("invalid proxy address") } @@ -69,83 +135,452 @@ func (ct *transport) Do(req *Request) (resp *Response, err os.Error) { return nil, os.ErrorString("invalid proxy address") } } - addr = proxyURL.Host - proxyInfo := proxyURL.RawUserinfo - if proxyInfo != "" { - enc := base64.URLEncoding - encoded := make([]byte, enc.EncodedLen(len(proxyInfo))) - enc.Encode(encoded, []byte(proxyInfo)) - proxyAuth = "Basic " + string(encoded) + cm.proxyURL = proxyURL + } + return cm, nil +} + +// proxyAuth returns the Proxy-Authorization header to set +// on requests, if applicable. +func (cm *connectMethod) proxyAuth() string { + if cm.proxyURL == nil { + return "" + } + proxyInfo := cm.proxyURL.RawUserinfo + if proxyInfo != "" { + enc := base64.URLEncoding + encoded := make([]byte, enc.EncodedLen(len(proxyInfo))) + enc.Encode(encoded, []byte(proxyInfo)) + return "Basic " + string(encoded) + } + return "" +} + +func (t *Transport) putIdleConn(pconn *persistConn) { + t.lk.Lock() + defer t.lk.Unlock() + if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { + pconn.close() + return + } + if pconn.isBroken() { + return + } + key := pconn.cacheKey + max := t.MaxIdleConnsPerHost + if max == 0 { + max = DefaultMaxIdleConnsPerHost + } + if len(t.idleConn[key]) >= max { + pconn.close() + return + } + t.idleConn[key] = append(t.idleConn[key], pconn) +} + +func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + t.idleConn = make(map[string][]*persistConn) + } + key := cm.String() + for { + pconns, ok := t.idleConn[key] + if !ok { + return nil + } + if len(pconns) == 1 { + pconn = pconns[0] + t.idleConn[key] = nil, false + } else { + // 2 or more cached connections; pop last + // TODO: queue? + pconn = pconns[len(pconns)-1] + t.idleConn[key] = pconns[0 : len(pconns)-1] } + if !pconn.isBroken() { + return + } + } + return +} + +// getConn dials and creates a new persistConn to the target as +// specified in the connectMethod. This includes doing a proxy CONNECT +// and/or setting up TLS. If this doesn't return an error, the persistConn +// is ready to write requests to. +func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { + if pc := t.getIdleConn(cm); pc != nil { + return pc, nil } - // Connect to server or proxy - conn, err := net.Dial("tcp", "", addr) + conn, err := net.Dial("tcp", cm.addr()) if err != nil { return nil, err } - if req.URL.Scheme == "http" { - // Include proxy http header if needed. - if proxyAuth != "" { - req.Header.Set("Proxy-Authorization", proxyAuth) - } - } else { // https - if proxyURL != nil { - // Ask proxy for direct connection to server. - // addr defaults above to ":https" but we need to use numbers - addr = req.URL.Host - if !hasPort(addr) { - addr += ":443" - } - fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", addr) - fmt.Fprintf(conn, "Host: %s\r\n", addr) - if proxyAuth != "" { - fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", proxyAuth) - } - fmt.Fprintf(conn, "\r\n") + pa := cm.proxyAuth() - // Read response. - // Okay to use and discard buffered reader here, because - // TLS server will not speak until spoken to. - br := bufio.NewReader(conn) - resp, err := ReadResponse(br, "CONNECT") - if err != nil { - return nil, err - } - if resp.StatusCode != 200 { - f := strings.Split(resp.Status, " ", 2) - return nil, os.ErrorString(f[1]) + pconn := &persistConn{ + t: t, + cacheKey: cm.String(), + conn: conn, + reqch: make(chan requestAndChan, 50), + } + newClientConnFunc := NewClientConn + + switch { + case cm.proxyURL == nil: + // Do nothing. + case cm.targetScheme == "http": + newClientConnFunc = NewProxyClientConn + if pa != "" { + pconn.mutateRequestFunc = func(req *Request) { + if req.Header == nil { + req.Header = make(Header) + } + req.Header.Set("Proxy-Authorization", pa) } } + case cm.targetScheme == "https": + fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", cm.targetAddr) + fmt.Fprintf(conn, "Host: %s\r\n", cm.targetAddr) + if pa != "" { + fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", pa) + } + fmt.Fprintf(conn, "\r\n") + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := ReadResponse(br, "CONNECT") + if err != nil { + conn.Close() + return nil, err + } + if resp.StatusCode != 200 { + f := strings.Split(resp.Status, " ", 2) + conn.Close() + return nil, os.ErrorString(f[1]) + } + } + + if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. conn = tls.Client(conn, nil) if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err } - h := req.URL.Host - if hasPort(h) { - h = h[:strings.LastIndex(h, ":")] - } - if err = conn.(*tls.Conn).VerifyHostname(h); err != nil { + if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil { return nil, err } + pconn.conn = conn } - err = write(req, conn) - if err != nil { - conn.Close() - return nil, err + pconn.br = bufio.NewReader(pconn.conn) + pconn.cc = newClientConnFunc(conn, pconn.br) + pconn.cc.readRes = readResponseWithEOFSignal + go pconn.readLoop() + return pconn, nil +} + +// useProxy returns true if requests to addr should use a proxy, +// according to the NO_PROXY or no_proxy environment variable. +func (t *Transport) useProxy(addr string) bool { + if len(addr) == 0 { + return true + } + no_proxy := t.getenvEitherCase("NO_PROXY") + if no_proxy == "*" { + return false } - reader := bufio.NewReader(conn) - resp, err = ReadResponse(reader, req.Method) + addr = strings.ToLower(strings.TrimSpace(addr)) + if hasPort(addr) { + addr = addr[:strings.LastIndex(addr, ":")] + } + + for _, p := range strings.Split(no_proxy, ",", -1) { + p = strings.ToLower(strings.TrimSpace(p)) + if len(p) == 0 { + continue + } + if hasPort(p) { + p = p[:strings.LastIndex(p, ":")] + } + if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) { + return false + } + } + return true +} + +// connectMethod is the map key (in its String form) for keeping persistent +// TCP connections alive for subsequent HTTP requests. +// +// A connect method may be of the following types: +// +// Cache key form Description +// ----------------- ------------------------- +// ||http|foo.com http directly to server, no proxy +// ||https|foo.com https directly to server, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// +// Note: no support to https to the proxy yet. +// +type connectMethod struct { + proxyURL *URL // "" for no proxy, else full proxy URL + targetScheme string // "http" or "https" + targetAddr string // Not used if proxy + http targetScheme (4th example in table) +} + +func (ck *connectMethod) String() string { + proxyStr := "" + if ck.proxyURL != nil { + proxyStr = ck.proxyURL.String() + } + return strings.Join([]string{proxyStr, ck.targetScheme, ck.targetAddr}, "|") +} + +// addr returns the first hop "host:port" to which we need to TCP connect. +func (cm *connectMethod) addr() string { + if cm.proxyURL != nil { + return canonicalAddr(cm.proxyURL) + } + return cm.targetAddr +} + +// tlsHost returns the host name to match against the peer's +// TLS certificate. +func (cm *connectMethod) tlsHost() string { + h := cm.targetAddr + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + return h +} + +type readResult struct { + res *Response // either res or err will be set + err os.Error +} + +type writeRequest struct { + // Set by client (in pc.roundTrip) + req *Request + resch chan *readResult + + // Set by writeLoop if an error writing headers. + writeErr os.Error +} + +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + t *Transport + cacheKey string // its connectMethod.String() + conn net.Conn + cc *ClientConn + br *bufio.Reader + reqch chan requestAndChan // written by roundTrip(); read by readLoop() + mutateRequestFunc func(*Request) // nil or func to modify each outbound request + + lk sync.Mutex // guards numExpectedResponses and broken + numExpectedResponses int + broken bool // an error has happened on this connection; marked broken so it's not reused. +} + +func (pc *persistConn) isBroken() bool { + pc.lk.Lock() + defer pc.lk.Unlock() + return pc.broken +} + +func (pc *persistConn) expectingResponse() bool { + pc.lk.Lock() + defer pc.lk.Unlock() + return pc.numExpectedResponses > 0 +} + +func (pc *persistConn) readLoop() { + alive := true + for alive { + pb, err := pc.br.Peek(1) + if err != nil { + if (err == os.EOF || err == os.EINVAL) && !pc.expectingResponse() { + // Remote side closed on us. (We probably hit their + // max idle timeout) + pc.close() + return + } + } + if !pc.expectingResponse() { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", + string(pb), err) + pc.close() + return + } + + rc := <-pc.reqch + resp, err := pc.cc.Read(rc.req) + + if err == ErrPersistEOF { + // Succeeded, but we can't send any more + // persistent connections on this again. We + // hide this error to upstream callers. + alive = false + err = nil + } else if err != nil || rc.req.Close { + alive = false + } + + hasBody := resp != nil && resp.ContentLength != 0 + var waitForBodyRead chan bool + if alive { + if hasBody { + waitForBodyRead = make(chan bool) + resp.Body.(*bodyEOFSignal).fn = func() { + pc.t.putIdleConn(pc) + waitForBodyRead <- true + } + } else { + pc.t.putIdleConn(pc) + } + } + + rc.ch <- responseAndError{resp, err} + + // Wait for the just-returned response body to be fully consumed + // before we race and peek on the underlying bufio reader. + if waitForBodyRead != nil { + <-waitForBodyRead + } + } +} + +type responseAndError struct { + res *Response + err os.Error +} + +type requestAndChan struct { + req *Request + ch chan responseAndError +} + +func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { + if pc.mutateRequestFunc != nil { + pc.mutateRequestFunc(req) + } + + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempted to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { + // Request gzip only, not deflate. Deflate is ambiguous and + // as universally supported anyway. + // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + requestedGzip = true + req.Header.Set("Accept-Encoding", "gzip") + } + + pc.lk.Lock() + pc.numExpectedResponses++ + pc.lk.Unlock() + + err = pc.cc.Write(req) if err != nil { - conn.Close() - return nil, err + pc.close() + return + } + + ch := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req, ch} + re := <-ch + pc.lk.Lock() + pc.numExpectedResponses-- + pc.lk.Unlock() + + if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" { + re.res.Header.Del("Content-Encoding") + re.res.Header.Del("Content-Length") + re.res.ContentLength = -1 + var err os.Error + re.res.Body, err = gzip.NewReader(re.res.Body) + if err != nil { + pc.close() + return nil, err + } + } + + return re.res, re.err +} + +func (pc *persistConn) close() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.broken = true + pc.cc.Close() + pc.conn.Close() + pc.mutateRequestFunc = nil +} + +var portMap = map[string]string{ + "http": "80", + "https": "443", +} + +// canonicalAddr returns url.Host but always with a ":port" suffix +func canonicalAddr(url *URL) string { + addr := url.Host + if !hasPort(addr) { + return addr + ":" + portMap[url.Scheme] } + return addr +} + +func responseIsKeepAlive(res *Response) bool { + // TODO: implement. for now just always shutting down the connection. + return false +} - resp.Body = readClose{resp.Body, conn} +// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces +// the response body with a bodyEOFSignal-wrapped version. +func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { + resp, err = ReadResponse(r, requestMethod) + if err == nil && resp.ContentLength != 0 { + resp.Body = &bodyEOFSignal{resp.Body, nil} + } + return +} + +// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before the final Read() or Close() call returns, but after +// EOF has been seen. +type bodyEOFSignal struct { + body io.ReadCloser + fn func() +} + +func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { + n, err = es.body.Read(p) + if err == os.EOF && es.fn != nil { + es.fn() + es.fn = nil + } + return +} + +func (es *bodyEOFSignal) Close() (err os.Error) { + err = es.body.Close() + if err == nil && es.fn != nil { + es.fn() + es.fn = nil + } return } diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go new file mode 100644 index 000000000..f83deedfc --- /dev/null +++ b/src/pkg/http/transport_test.go @@ -0,0 +1,450 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for transport.go + +package http_test + +import ( + "bytes" + "compress/gzip" + "fmt" + . "http" + "http/httptest" + "io/ioutil" + "os" + "testing" + "time" +) + +// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close +// and then verify that the final 2 responses get errors back. + +// hostPortHandler writes back the client's "host:port". +var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + if r.FormValue("close") == "true" { + w.Header().Set("Connection", "close") + } + w.Write([]byte(r.RemoteAddr)) +}) + +// Two subsequent requests and verify their response is the same. +// The response from the server is our own IP:port +func TestTransportKeepAlives(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, disableKeepAlive := range []bool{false, true} { + tr := &Transport{DisableKeepAlives: disableKeepAlive} + c := &Client{Transport: tr} + + fetch := func(n int) string { + res, _, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + bodiesDiffer := body1 != body2 + if bodiesDiffer != disableKeepAlive { + t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + disableKeepAlive, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportConnectionCloseOnResponse(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, connectionClose := range []bool{false, true} { + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n int) string { + req := new(Request) + var err os.Error + req.URL, err = ParseURL(ts.URL + fmt.Sprintf("?close=%v", connectionClose)) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + body, err := ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportConnectionCloseOnRequest(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, connectionClose := range []bool{false, true} { + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n int) string { + req := new(Request) + var err os.Error + req.URL, err = ParseURL(ts.URL) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Close = connectionClose + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportIdleCacheKeys(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } + + resp, _, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + ioutil.ReadAll(resp.Body) + + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) + } + + if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { + t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) + } + + tr.CloseIdleConnections() + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } +} + +func TestTransportMaxPerHostIdleConns(t *testing.T) { + ch := make(chan string) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte(<-ch)) + })) + defer ts.Close() + maxIdleConns := 2 + tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} + c := &Client{Transport: tr} + + // Start 3 outstanding requests (will hang until we write to + // ch) + donech := make(chan bool) + doReq := func() { + resp, _, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + ioutil.ReadAll(resp.Body) + donech <- true + } + go doReq() + go doReq() + go doReq() + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) + } + + ch <- "res1" + <-donech + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) + } + cacheKey := "|http|" + ts.Listener.Addr().String() + if keys[0] != cacheKey { + t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) + } + if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after first response, expected %d idle conns; got %d", e, g) + } + + ch <- "res2" + <-donech + if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after second response, expected %d idle conns; got %d", e, g) + } + + ch <- "res3" + <-donech + if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after third response, still expected %d idle conns; got %d", e, g) + } +} + +func TestTransportServerClosingUnexpectedly(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n int) string { + res, _, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("error in req #%d, GET: %v", n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in req #%d, ReadAll: %v", n, err) + } + res.Body.Close() + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + ts.CloseClientConnections() // surprise! + time.Sleep(25e6) // idle for a bit (test is inherently racey, but expectedly) + + body3 := fetch(3) + + if body1 != body2 { + t.Errorf("expected body1 and body2 to be equal") + } + if body2 == body3 { + t.Errorf("expected body2 and body3 to be different") + } +} + +// TestTransportHeadResponses verifies that we deal with Content-Lengths +// with no bodies properly +func TestTransportHeadResponses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Content-Length", "123") + w.WriteHeader(200) + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + for i := 0; i < 2; i++ { + res, err := c.Head(ts.URL) + if err != nil { + t.Errorf("error on loop %d: %v", i, err) + } + if e, g := "123", res.Header.Get("Content-Length"); e != g { + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) + } + if e, g := int64(0), res.ContentLength; e != g { + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) + } + } +} + +// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding +// on responses to HEAD requests. +func TestTransportHeadChunkedResponse(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Transfer-Encoding", "chunked") // client should ignore + w.Header().Set("x-client-ipport", r.RemoteAddr) + w.WriteHeader(200) + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + res1, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("request 1 error: %v", err) + } + res2, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("request 2 error: %v", err) + } + if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { + t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) + } +} + +func TestTransportNilURL(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hi") + })) + defer ts.Close() + + req := new(Request) + req.URL = nil // what we're actually testing + req.Method = "GET" + req.RawURL = ts.URL + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Header = make(Header) + + tr := &Transport{} + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected RoundTrip error: %v", err) + } + body, err := ioutil.ReadAll(res.Body) + if g, e := string(body), "Hi"; g != e { + t.Fatalf("Expected response body of %q; got %q", e, g) + } +} + +func TestTransportGzip(t *testing.T) { + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if g, e := r.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + w.Header().Set("Content-Encoding", "gzip") + gz, _ := gzip.NewWriter(w) + defer gz.Close() + gz.Write([]byte(testString)) + + })) + defer ts.Close() + + c := &Client{Transport: &Transport{}} + res, _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })) + defer ts.Close() + + c := &Client{Transport: &Transport{}} + res, _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} diff --git a/src/pkg/http/triv.go b/src/pkg/http/triv.go index 52d521d3d..bff6a106d 100644 --- a/src/pkg/http/triv.go +++ b/src/pkg/http/triv.go @@ -56,7 +56,7 @@ func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) { var booleanflag = flag.Bool("boolean", true, "another flag for testing") func FlagServer(w http.ResponseWriter, req *http.Request) { - w.SetHeader("content-type", "text/plain; charset=utf-8") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") fmt.Fprint(w, "Flags:\n") flag.VisitAll(func(f *flag.Flag) { if f.Value.String() != f.DefValue { @@ -93,13 +93,14 @@ func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) { // exec a program, redirecting output func DateServer(rw http.ResponseWriter, req *http.Request) { - rw.SetHeader("content-type", "text/plain; charset=utf-8") + rw.Header().Set("Content-Type", "text/plain; charset=utf-8") r, w, err := os.Pipe() if err != nil { fmt.Fprintf(rw, "pipe: %s\n", err) return } - p, err := os.StartProcess("/bin/date", []string{"date"}, os.Environ(), "", []*os.File{nil, w, w}) + + p, err := os.StartProcess("/bin/date", []string{"date"}, &os.ProcAttr{Files: []*os.File{nil, w, w}}) defer r.Close() w.Close() if err != nil { diff --git a/src/pkg/http/url.go b/src/pkg/http/url.go index efd90d81e..0fc0cb2d7 100644 --- a/src/pkg/http/url.go +++ b/src/pkg/http/url.go @@ -213,8 +213,8 @@ func urlEscape(s string, mode encoding) string { j++ case shouldEscape(c, mode): t[j] = '%' - t[j+1] = "0123456789abcdef"[c>>4] - t[j+2] = "0123456789abcdef"[c&15] + t[j+1] = "0123456789ABCDEF"[c>>4] + t[j+2] = "0123456789ABCDEF"[c&15] j += 3 default: t[j] = s[i] diff --git a/src/pkg/http/url_test.go b/src/pkg/http/url_test.go index 0801f7ff3..d8863f3d3 100644 --- a/src/pkg/http/url_test.go +++ b/src/pkg/http/url_test.go @@ -490,7 +490,7 @@ var escapeTests = []URLEscapeTest{ }, { " ?&=#+%!<>#\"{}|\\^[]`☺\t", - "+%3f%26%3d%23%2b%25!%3c%3e%23%22%7b%7d%7c%5c%5e%5b%5d%60%e2%98%ba%09", + "+%3F%26%3D%23%2B%25!%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09", nil, }, } @@ -519,7 +519,7 @@ type UserinfoTest struct { var userinfoTests = []UserinfoTest{ {"user", "password", "user:password"}, {"foo:bar", "~!@#$%^&*()_+{}|[]\\-=`:;'\"<>?,./", - "foo%3abar:~!%40%23$%25%5e&*()_+%7b%7d%7c%5b%5d%5c-=%60%3a;'%22%3c%3e?,.%2f"}, + "foo%3Abar:~!%40%23$%25%5E&*()_+%7B%7D%7C%5B%5D%5C-=%60%3A;'%22%3C%3E?,.%2F"}, } func TestEscapeUserinfo(t *testing.T) { |