diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-06-30 15:34:22 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-06-30 15:34:22 +0200 |
commit | d39f5aa373a4422f7a5f3ee764fb0f6b0b719d61 (patch) | |
tree | 1833f8b72a4b3a8f00d0d143b079a8fcad01c6ae /src/pkg/http | |
parent | 8652e6c371b8905498d3d314491d36c58d5f68d5 (diff) | |
download | golang-upstream/58.tar.gz |
Imported Upstream version 58upstream/58
Diffstat (limited to 'src/pkg/http')
38 files changed, 2846 insertions, 593 deletions
diff --git a/src/pkg/http/cgi/child.go b/src/pkg/http/cgi/child.go index c7d48b9eb..e1ad7ad32 100644 --- a/src/pkg/http/cgi/child.go +++ b/src/pkg/http/cgi/child.go @@ -9,10 +9,12 @@ package cgi import ( "bufio" + "crypto/tls" "fmt" "http" "io" "io/ioutil" + "net" "os" "strconv" "strings" @@ -21,8 +23,16 @@ import ( // Request returns the HTTP request as represented in the current // environment. This assumes the current program is being run // by a web server in a CGI environment. +// The returned Request's Body is populated, if applicable. func Request() (*http.Request, os.Error) { - return requestFromEnvironment(envMap(os.Environ())) + r, err := RequestFromMap(envMap(os.Environ())) + if err != nil { + return nil, err + } + if r.ContentLength > 0 { + r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + } + return r, nil } func envMap(env []string) map[string]string { @@ -42,37 +52,44 @@ var skipHeader = map[string]bool{ "HTTP_USER_AGENT": true, } -func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { +// RequestFromMap creates an http.Request from CGI variables. +// The returned Request's Body field is not populated. +func RequestFromMap(params map[string]string) (*http.Request, os.Error) { r := new(http.Request) - r.Method = env["REQUEST_METHOD"] + r.Method = params["REQUEST_METHOD"] if r.Method == "" { return nil, os.NewError("cgi: no REQUEST_METHOD in environment") } + + r.Proto = params["SERVER_PROTOCOL"] + var ok bool + r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto) + if !ok { + return nil, os.NewError("cgi: invalid SERVER_PROTOCOL version") + } + r.Close = true r.Trailer = http.Header{} r.Header = http.Header{} - r.Host = env["HTTP_HOST"] - r.Referer = env["HTTP_REFERER"] - r.UserAgent = env["HTTP_USER_AGENT"] + r.Host = params["HTTP_HOST"] + r.Referer = params["HTTP_REFERER"] + r.UserAgent = params["HTTP_USER_AGENT"] - // CGI doesn't allow chunked requests, so these should all be accurate: - r.Proto = "HTTP/1.0" - r.ProtoMajor = 1 - r.ProtoMinor = 0 - r.TransferEncoding = nil - - if lenstr := env["CONTENT_LENGTH"]; lenstr != "" { + if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { clen, err := strconv.Atoi64(lenstr) if err != nil { return nil, os.NewError("cgi: bad CONTENT_LENGTH in environment: " + lenstr) } r.ContentLength = clen - r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, clen)) + } + + if ct := params["CONTENT_TYPE"]; ct != "" { + r.Header.Set("Content-Type", ct) } // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers - for k, v := range env { + for k, v := range params { if !strings.HasPrefix(k, "HTTP_") || skipHeader[k] { continue } @@ -84,7 +101,7 @@ func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { if r.Host != "" { // Hostname is provided, so we can reasonably construct a URL, // even if we have to assume 'http' for the scheme. - r.RawURL = "http://" + r.Host + env["REQUEST_URI"] + r.RawURL = "http://" + r.Host + params["REQUEST_URI"] url, err := http.ParseURL(r.RawURL) if err != nil { return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + r.RawURL) @@ -94,13 +111,25 @@ func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { // Fallback logic if we don't have a Host header or the URL // failed to parse if r.URL == nil { - r.RawURL = env["REQUEST_URI"] + r.RawURL = params["REQUEST_URI"] url, err := http.ParseURL(r.RawURL) if err != nil { return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + r.RawURL) } r.URL = url } + + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + + // Request.RemoteAddr has its port set by Go's standard http + // server, so we do here too. We don't have one, though, so we + // use a dummy one. + r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0") + return r, nil } @@ -139,10 +168,6 @@ func (r *response) Flush() { r.bufw.Flush() } -func (r *response) RemoteAddr() string { - return os.Getenv("REMOTE_ADDR") -} - func (r *response) Header() http.Header { return r.header } @@ -168,25 +193,7 @@ func (r *response) WriteHeader(code int) { r.header.Add("Content-Type", "text/html; charset=utf-8") } - // TODO: add a method on http.Header to write itself to an io.Writer? - // This is duplicated code. - for k, vv := range r.header { - for _, v := range vv { - v = strings.Replace(v, "\n", "", -1) - v = strings.Replace(v, "\r", "", -1) - v = strings.TrimSpace(v) - fmt.Fprintf(r.bufw, "%s: %s\r\n", k, v) - } - } - r.bufw.Write([]byte("\r\n")) + r.header.Write(r.bufw) + r.bufw.WriteString("\r\n") r.bufw.Flush() } - -func (r *response) UsingTLS() bool { - // There's apparently a de-facto standard for this. - // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 - if s := os.Getenv("HTTPS"); s == "on" || s == "ON" || s == "1" { - return true - } - return false -} diff --git a/src/pkg/http/cgi/child_test.go b/src/pkg/http/cgi/child_test.go index db0e09cf6..d12947814 100644 --- a/src/pkg/http/cgi/child_test.go +++ b/src/pkg/http/cgi/child_test.go @@ -12,6 +12,7 @@ import ( func TestRequest(t *testing.T) { env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", "REQUEST_METHOD": "GET", "HTTP_HOST": "example.com", "HTTP_REFERER": "elsewhere", @@ -19,10 +20,13 @@ func TestRequest(t *testing.T) { "HTTP_FOO_BAR": "baz", "REQUEST_URI": "/path?a=b", "CONTENT_LENGTH": "123", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", } - req, err := requestFromEnvironment(env) + req, err := RequestFromMap(env) if err != nil { - t.Fatalf("requestFromEnvironment: %v", err) + t.Fatalf("RequestFromMap: %v", err) } if g, e := req.UserAgent, "goclient"; e != g { t.Errorf("expected UserAgent %q; got %q", e, g) @@ -34,6 +38,9 @@ func TestRequest(t *testing.T) { // Tests that we don't put recognized headers in the map t.Errorf("expected User-Agent %q; got %q", e, g) } + if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { + t.Errorf("expected Content-Type %q; got %q", e, g) + } if g, e := req.ContentLength, int64(123); e != g { t.Errorf("expected ContentLength %d; got %d", e, g) } @@ -58,18 +65,25 @@ func TestRequest(t *testing.T) { if req.Trailer == nil { t.Errorf("unexpected nil Trailer") } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } + if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } } func TestRequestWithoutHost(t *testing.T) { env := map[string]string{ - "HTTP_HOST": "", - "REQUEST_METHOD": "GET", - "REQUEST_URI": "/path?a=b", - "CONTENT_LENGTH": "123", + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", } - req, err := requestFromEnvironment(env) + req, err := RequestFromMap(env) if err != nil { - t.Fatalf("requestFromEnvironment: %v", err) + t.Fatalf("RequestFromMap: %v", err) } if g, e := req.RawURL, "/path?a=b"; e != g { t.Errorf("expected RawURL %q; got %q", e, g) diff --git a/src/pkg/http/cgi/host.go b/src/pkg/http/cgi/host.go index 136d4e4ee..7ab3f9247 100644 --- a/src/pkg/http/cgi/host.go +++ b/src/pkg/http/cgi/host.go @@ -36,7 +36,9 @@ var osDefaultInheritEnv = map[string][]string{ "darwin": []string{"DYLD_LIBRARY_PATH"}, "freebsd": []string{"LD_LIBRARY_PATH"}, "hpux": []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}, + "irix": []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}, "linux": []string{"LD_LIBRARY_PATH"}, + "solaris": []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}, "windows": []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, } @@ -86,6 +88,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env := []string{ "SERVER_SOFTWARE=go", "SERVER_NAME=" + req.Host, + "SERVER_PROTOCOL=HTTP/1.1", "HTTP_HOST=" + req.Host, "GATEWAY_INTERFACE=CGI/1.1", "REQUEST_METHOD=" + req.Method, @@ -153,34 +156,35 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { cwd = "." } - args := []string{h.Path} - args = append(args, h.Args...) - - cmd, err := exec.Run( - pathBase, - args, - env, - cwd, - exec.Pipe, // stdin - exec.Pipe, // stdout - exec.PassThrough, // stderr (for now) - ) - if err != nil { + internalError := func(err os.Error) { rw.WriteHeader(http.StatusInternalServerError) h.printf("CGI error: %v", err) - return } - defer func() { - cmd.Stdin.Close() - cmd.Stdout.Close() - cmd.Wait(0) // no zombies - }() + cmd := &exec.Cmd{ + Path: pathBase, + Args: append([]string{h.Path}, h.Args...), + Dir: cwd, + Env: env, + Stderr: os.Stderr, // for now + } if req.ContentLength != 0 { - go io.Copy(cmd.Stdin, req.Body) + cmd.Stdin = req.Body + } + stdoutRead, err := cmd.StdoutPipe() + if err != nil { + internalError(err) + return + } + + err = cmd.Start() + if err != nil { + internalError(err) + return } + defer cmd.Wait() - linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024) + linebody, _ := bufio.NewReaderSize(stdoutRead, 1024) headers := make(http.Header) statusCode := 0 for { diff --git a/src/pkg/http/cgi/host_test.go b/src/pkg/http/cgi/host_test.go index 9ac085f2f..bbdb715cf 100644 --- a/src/pkg/http/cgi/host_test.go +++ b/src/pkg/http/cgi/host_test.go @@ -17,20 +17,6 @@ import ( "testing" ) -var cgiScriptWorks = canRun("./testdata/test.cgi") - -func canRun(s string) bool { - c, err := exec.Run(s, []string{s}, nil, ".", exec.DevNull, exec.DevNull, exec.DevNull) - if err != nil { - return false - } - w, err := c.Wait(0) - if err != nil { - return false - } - return w.Exited() && w.ExitStatus() == 0 -} - func newRequest(httpreq string) *http.Request { buf := bufio.NewReader(strings.NewReader(httpreq)) req, err := http.ReadRequest(buf) @@ -76,8 +62,15 @@ readlines: return rw } +var cgiTested = false +var cgiWorks bool + func skipTest(t *testing.T) bool { - if !cgiScriptWorks { + if !cgiTested { + cgiTested = true + cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil + } + if !cgiWorks { // No Perl on Windows, needed by test.cgi // TODO: make the child process be Go, not Perl. t.Logf("Skipping test: test.cgi failed.") diff --git a/src/pkg/http/chunked.go b/src/pkg/http/chunked.go index 66195f06b..59121c5a2 100644 --- a/src/pkg/http/chunked.go +++ b/src/pkg/http/chunked.go @@ -6,19 +6,29 @@ package http import ( "io" + "log" "os" "strconv" ) // NewChunkedWriter returns a new writer that translates writes into HTTP -// "chunked" format before writing them to w. Closing the returned writer +// "chunked" format before writing them to w. Closing the returned writer // sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. func NewChunkedWriter(w io.Writer) io.WriteCloser { + if _, bad := w.(*response); bad { + log.Printf("warning: using NewChunkedWriter in an http.Handler; expect corrupt output") + } return &chunkedWriter{w} } // Writing to ChunkedWriter translates to writing in HTTP chunked Transfer -// Encoding wire format to the undering Wire writer. +// Encoding wire format to the underlying Wire writer. type chunkedWriter struct { Wire io.Writer } diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index d73cbc855..71b037042 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -7,13 +7,10 @@ package http import ( - "bytes" "encoding/base64" "fmt" "io" - "io/ioutil" "os" - "strconv" "strings" ) @@ -74,6 +71,9 @@ type readClose struct { // // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err os.Error) { + if req.Method == "GET" || req.Method == "HEAD" { + return c.doFollowingRedirects(req) + } return send(req, c.Transport) } @@ -97,13 +97,10 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) { info := req.URL.RawUserinfo if len(info) > 0 { - enc := base64.URLEncoding - encoded := make([]byte, enc.EncodedLen(len(info))) - enc.Encode(encoded, []byte(info)) if req.Header == nil { req.Header = make(Header) } - req.Header.Set("Authorization", "Basic "+string(encoded)) + req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(info))) } return t.RoundTrip(req) } @@ -126,13 +123,10 @@ func shouldRedirect(statusCode int) bool { // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical to the -// input URL unless redirects were followed. -// // Caller should close r.Body when done reading from it. // // Get is a convenience wrapper around DefaultClient.Get. -func Get(url string) (r *Response, finalURL string, err os.Error) { +func Get(url string) (r *Response, err os.Error) { return DefaultClient.Get(url) } @@ -145,11 +139,16 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical -// to the input URL unless redirects were followed. -// // Caller should close r.Body when done reading from it. -func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { +func (c *Client) Get(url string) (r *Response, err os.Error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.doFollowingRedirects(req) +} + +func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. var base *URL @@ -159,33 +158,33 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { } var via []*Request + req := ireq + url := "" // next relative or absolute URL to fetch (after first request) for redirect := 0; ; redirect++ { - var req Request - req.Method = "GET" - req.Header = make(Header) - if base == nil { - req.URL, err = ParseURL(url) - } else { + if redirect != 0 { + req = new(Request) + req.Method = ireq.Method + req.Header = make(Header) req.URL, err = base.ParseURL(url) - } - if err != nil { - break - } - if len(via) > 0 { - // Add the Referer header. - lastReq := via[len(via)-1] - if lastReq.URL.Scheme != "https" { - req.Referer = lastReq.URL.String() - } - - err = redirectChecker(&req, via) if err != nil { break } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Referer = lastReq.URL.String() + } + + err = redirectChecker(req, via) + if err != nil { + break + } + } } url = req.URL.String() - if r, err = send(&req, c.Transport); err != nil { + if r, err = send(req, c.Transport); err != nil { break } if shouldRedirect(r.StatusCode) { @@ -195,14 +194,14 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { break } base = req.URL - via = append(via, &req) + via = append(via, req) continue } - finalURL = url return } - err = &URLError{"Get", url, err} + method := ireq.Method + err = &URLError{method[0:1] + strings.ToLower(method[1:]), url, err} return } @@ -226,23 +225,12 @@ func Post(url string, bodyType string, body io.Reader) (r *Response, err os.Erro // // Caller should close r.Body when done reading from it. func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err os.Error) { - var req Request - req.Method = "POST" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Close = true - req.Body = ioutil.NopCloser(body) - req.Header = Header{ - "Content-Type": {bodyType}, - } - req.TransferEncoding = []string{"chunked"} - - req.URL, err = ParseURL(url) + req, err := NewRequest("POST", url, body) if err != nil { return nil, err } - - return send(&req, c.Transport) + req.Header.Set("Content-Type", bodyType) + return send(req, c.Transport) } // PostForm issues a POST to the specified URL, @@ -251,7 +239,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, // Caller should close r.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm -func PostForm(url string, data map[string]string) (r *Response, err os.Error) { +func PostForm(url string, data Values) (r *Response, err os.Error) { return DefaultClient.PostForm(url, data) } @@ -259,50 +247,36 @@ func PostForm(url string, data map[string]string) (r *Response, err os.Error) { // with data's keys and values urlencoded as the request body. // // Caller should close r.Body when done reading from it. -func (c *Client) PostForm(url string, data map[string]string) (r *Response, err os.Error) { - var req Request - req.Method = "POST" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Close = true - body := urlencode(data) - req.Body = ioutil.NopCloser(body) - req.Header = Header{ - "Content-Type": {"application/x-www-form-urlencoded"}, - "Content-Length": {strconv.Itoa(body.Len())}, - } - req.ContentLength = int64(body.Len()) - - req.URL, err = ParseURL(url) - if err != nil { - return nil, err - } - - return send(&req, c.Transport) +func (c *Client) PostForm(url string, data Values) (r *Response, err os.Error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } -// TODO: remove this function when PostForm takes a multimap. -func urlencode(data map[string]string) (b *bytes.Buffer) { - m := make(map[string][]string, len(data)) - for k, v := range data { - m[k] = []string{v} - } - return bytes.NewBuffer([]byte(EncodeQuery(m))) -} - -// Head issues a HEAD to the specified URL. +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) // // Head is a wrapper around DefaultClient.Head func Head(url string) (r *Response, err os.Error) { return DefaultClient.Head(url) } -// Head issues a HEAD to the specified URL. +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) func (c *Client) Head(url string) (r *Response, err os.Error) { - var req Request - req.Method = "HEAD" - if req.URL, err = ParseURL(url); err != nil { - return + req, err := NewRequest("HEAD", url, nil) + if err != nil { + return nil, err } - return send(&req, c.Transport) + return c.doFollowingRedirects(req) } diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index 59d62c1c9..9ef81d9d4 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -10,6 +10,7 @@ import ( "fmt" . "http" "http/httptest" + "io" "io/ioutil" "os" "strconv" @@ -26,7 +27,7 @@ func TestClient(t *testing.T) { ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() - r, _, err := Get(ts.URL) + r, err := Get(ts.URL) var b []byte if err == nil { b, err = ioutil.ReadAll(r.Body) @@ -77,6 +78,71 @@ func TestGetRequestFormat(t *testing.T) { } } +func TestPostRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + json := `{"key":"value"}` + b := strings.NewReader(json) + client.Post(url, "application/json", b) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if tr.req.Close { + t.Error("got Close true, want false") + } + if g, e := tr.req.ContentLength, int64(len(json)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } +} + +func TestPostFormRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + form := make(Values) + form.Set("foo", "bar") + form.Add("foo", "bar2") + form.Set("bar", "baz") + client.PostForm(url, form) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e { + t.Errorf("got Content-Type %q, want %q", g, e) + } + if tr.req.Close { + t.Error("got Close true, want false") + } + expectedBody := "foo=bar&foo=bar2&bar=baz" + if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } + bodyb, err := ioutil.ReadAll(tr.req.Body) + if err != nil { + t.Fatalf("ReadAll on req.Body: %v", err) + } + if g := string(bodyb); g != expectedBody { + t.Errorf("got body %q, want %q", g, expectedBody) + } +} + func TestRedirects(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -96,9 +162,22 @@ func TestRedirects(t *testing.T) { defer ts.Close() c := &Client{} - _, _, err := c.Get(ts.URL) + _, err := c.Get(ts.URL) if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { - t.Errorf("with default client, expected error %q, got %q", e, g) + t.Errorf("with default client Get, expected error %q, got %q", e, g) + } + + // HEAD request should also have the ability to follow redirects. + _, err = c.Head(ts.URL) + if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Head, expected error %q, got %q", e, g) + } + + // Do should also follow redirects. + greq, _ := NewRequest("GET", ts.URL, nil) + _, err = c.Do(greq) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Do, expected error %q, got %q", e, g) } var checkErr os.Error @@ -107,7 +186,8 @@ func TestRedirects(t *testing.T) { lastVia = via return checkErr }} - _, finalUrl, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) + finalUrl := res.Request.URL.String() if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { t.Errorf("with custom client, expected error %q, got %q", e, g) } @@ -119,8 +199,47 @@ func TestRedirects(t *testing.T) { } checkErr = os.NewError("no redirects allowed") - _, finalUrl, err = c.Get(ts.URL) + res, err = c.Get(ts.URL) + finalUrl = res.Request.URL.String() if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) } } + +func TestStreamingGet(t *testing.T) { + say := make(chan string) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + for str := range say { + w.Write([]byte(str)) + w.(Flusher).Flush() + } + })) + defer ts.Close() + + c := &Client{} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + var buf [10]byte + for _, str := range []string{"i", "am", "also", "known", "as", "comet"} { + say <- str + n, err := io.ReadFull(res.Body, buf[0:len(str)]) + if err != nil { + t.Fatalf("ReadFull on %q: %v", str, err) + } + if n != len(str) { + t.Fatalf("Receiving %q, only read %d bytes", str, n) + } + got := string(buf[0:n]) + if got != str { + t.Fatalf("Expected %q, got %q", str, got) + } + } + close(say) + _, err = io.ReadFull(res.Body, buf[0:1]) + if err != os.EOF { + t.Fatalf("at end expected EOF, got %v", err) + } +} diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go index 2c01826a1..eb61a7001 100644 --- a/src/pkg/http/cookie.go +++ b/src/pkg/http/cookie.go @@ -15,9 +15,9 @@ import ( "time" ) -// This implementation is done according to IETF draft-ietf-httpstate-cookie-23, found at +// This implementation is done according to RFC 6265: // -// http://tools.ietf.org/html/draft-ietf-httpstate-cookie-23 +// http://tools.ietf.org/html/rfc6265 // A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an // HTTP response or the Cookie header of an HTTP request. @@ -81,12 +81,17 @@ func readSetCookies(h Header) []*Cookie { if j := strings.Index(attr, "="); j >= 0 { attr, val = attr[:j], attr[j+1:] } - val, success = parseCookieValue(val) + lowerAttr := strings.ToLower(attr) + parseCookieValueFn := parseCookieValue + if lowerAttr == "expires" { + parseCookieValueFn = parseCookieExpiresValue + } + val, success = parseCookieValueFn(val) if !success { c.Unparsed = append(c.Unparsed, parts[i]) continue } - switch strings.ToLower(attr) { + switch lowerAttr { case "secure": c.Secure = true continue @@ -112,8 +117,11 @@ func readSetCookies(h Header) []*Cookie { c.RawExpires = val exptime, err := time.Parse(time.RFC1123, val) if err != nil { - c.Expires = time.Time{} - break + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } } c.Expires = *exptime continue @@ -130,6 +138,37 @@ func readSetCookies(h Header) []*Cookie { return cookies } +// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. +func SetCookie(w ResponseWriter, cookie *Cookie) { + var b bytes.Buffer + writeSetCookieToBuffer(&b, cookie) + w.Header().Add("Set-Cookie", b.String()) +} + +func writeSetCookieToBuffer(buf *bytes.Buffer, c *Cookie) { + fmt.Fprintf(buf, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) + if len(c.Path) > 0 { + fmt.Fprintf(buf, "; Path=%s", sanitizeValue(c.Path)) + } + if len(c.Domain) > 0 { + fmt.Fprintf(buf, "; Domain=%s", sanitizeValue(c.Domain)) + } + if len(c.Expires.Zone) > 0 { + fmt.Fprintf(buf, "; Expires=%s", c.Expires.Format(time.RFC1123)) + } + if c.MaxAge > 0 { + fmt.Fprintf(buf, "; Max-Age=%d", c.MaxAge) + } else if c.MaxAge < 0 { + fmt.Fprintf(buf, "; Max-Age=0") + } + if c.HttpOnly { + fmt.Fprintf(buf, "; HttpOnly") + } + if c.Secure { + fmt.Fprintf(buf, "; Secure") + } +} + // writeSetCookies writes the wire representation of the set-cookies // to w. Each cookie is written on a separate "Set-Cookie: " line. // This choice is made because HTTP parsers tend to have a limit on @@ -142,27 +181,7 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { var b bytes.Buffer for _, c := range kk { b.Reset() - fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) - if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) - } - if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) - } - if len(c.Expires.Zone) > 0 { - fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) - } - if c.MaxAge > 0 { - fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) - } else if c.MaxAge < 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } - if c.HttpOnly { - fmt.Fprintf(&b, "; HttpOnly") - } - if c.Secure { - fmt.Fprintf(&b, "; Secure") - } + writeSetCookieToBuffer(&b, c) lines = append(lines, "Set-Cookie: "+b.String()+"\r\n") } sort.SortStrings(lines) @@ -218,22 +237,26 @@ func readCookies(h Header) []*Cookie { return cookies } -// writeCookies writes the wire representation of the cookies -// to w. Each cookie is written on a separate "Cookie: " line. -// This choice is made because HTTP parsers tend to have a limit on -// line-length, so it seems safer to place cookies on separate lines. +// writeCookies writes the wire representation of the cookies to +// w. According to RFC 6265 section 5.4, writeCookies does not +// attach more than one Cookie header field. That means all +// cookies, if any, are written into the same line, separated by +// semicolon. func writeCookies(w io.Writer, kk []*Cookie) os.Error { - lines := make([]string, 0, len(kk)) - for _, c := range kk { - lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", sanitizeName(c.Name), sanitizeValue(c.Value))) + if len(kk) == 0 { + return nil } - sort.SortStrings(lines) - for _, l := range lines { - if _, err := io.WriteString(w, l); err != nil { - return err + var buf bytes.Buffer + fmt.Fprintf(&buf, "Cookie: ") + for i, c := range kk { + if i > 0 { + fmt.Fprintf(&buf, "; ") } + fmt.Fprintf(&buf, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) } - return nil + fmt.Fprintf(&buf, "\r\n") + _, err := w.Write(buf.Bytes()) + return err } func sanitizeName(n string) string { @@ -257,7 +280,7 @@ func unquoteCookieValue(v string) string { } func isCookieByte(c byte) bool { - switch true { + switch { case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a, 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e: return true @@ -265,10 +288,22 @@ func isCookieByte(c byte) bool { return false } +func isCookieExpiresByte(c byte) (ok bool) { + return isCookieByte(c) || c == ',' || c == ' ' +} + func parseCookieValue(raw string) (string, bool) { + return parseCookieValueUsing(raw, isCookieByte) +} + +func parseCookieExpiresValue(raw string) (string, bool) { + return parseCookieValueUsing(raw, isCookieExpiresByte) +} + +func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) { raw = unquoteCookieValue(raw) for i := 0; i < len(raw); i++ { - if !isCookieByte(raw[i]) { + if !validByte(raw[i]) { return "", false } } diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go index a3ae85cd6..02e42226b 100644 --- a/src/pkg/http/cookie_test.go +++ b/src/pkg/http/cookie_test.go @@ -8,11 +8,12 @@ import ( "bytes" "fmt" "json" + "os" "reflect" "testing" + "time" ) - var writeSetCookiesTests = []struct { Cookies []*Cookie Raw string @@ -43,14 +44,55 @@ func TestWriteSetCookies(t *testing.T) { } } +type headerOnlyResponseWriter Header + +func (ho headerOnlyResponseWriter) Header() Header { + return Header(ho) +} + +func (ho headerOnlyResponseWriter) Write([]byte) (int, os.Error) { + panic("NOIMPL") +} + +func (ho headerOnlyResponseWriter) WriteHeader(int) { + panic("NOIMPL") +} + +func TestSetCookie(t *testing.T) { + m := make(Header) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"}) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}) + if l := len(m["Set-Cookie"]); l != 2 { + t.Fatalf("expected %d cookies, got %d", 2, l) + } + if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e { + t.Errorf("cookie #1: want %q, got %q", e, g) + } + if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e { + t.Errorf("cookie #2: want %q, got %q", e, g) + } +} + var writeCookiesTests = []struct { Cookies []*Cookie Raw string }{ { + []*Cookie{}, + "", + }, + { []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1"}}, "Cookie: cookie-1=v$1\r\n", }, + { + []*Cookie{ + &Cookie{Name: "cookie-1", Value: "v$1"}, + &Cookie{Name: "cookie-2", Value: "v$2"}, + &Cookie{Name: "cookie-3", Value: "v$3"}, + }, + "Cookie: cookie-1=v$1; cookie-2=v$2; cookie-3=v$3\r\n", + }, } func TestWriteCookies(t *testing.T) { @@ -73,6 +115,19 @@ var readSetCookiesTests = []struct { Header{"Set-Cookie": {"Cookie-1=v$1"}}, []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}}, }, + { + Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}}, + []*Cookie{&Cookie{ + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Time{Year: 2011, Month: 11, Day: 23, Hour: 1, Minute: 5, Second: 3, Weekday: 3, ZoneOffset: 0, Zone: "GMT"}, + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }}, + }, } func toJSON(v interface{}) string { diff --git a/src/pkg/http/fcgi/child.go b/src/pkg/http/fcgi/child.go index 114052bee..19718824c 100644 --- a/src/pkg/http/fcgi/child.go +++ b/src/pkg/http/fcgi/child.go @@ -9,11 +9,10 @@ package fcgi import ( "fmt" "http" + "http/cgi" "io" "net" "os" - "strconv" - "strings" "time" ) @@ -38,68 +37,6 @@ func newRequest(reqId uint16, flags uint8) *request { return r } -// TODO(eds): copied from http/cgi -var skipHeader = map[string]bool{ - "HTTP_HOST": true, - "HTTP_REFERER": true, - "HTTP_USER_AGENT": true, -} - -// httpRequest converts r to an http.Request. -// TODO(eds): this is very similar to http/cgi's requestFromEnvironment -func (r *request) httpRequest(body io.ReadCloser) (*http.Request, os.Error) { - req := &http.Request{ - Method: r.params["REQUEST_METHOD"], - RawURL: r.params["REQUEST_URI"], - Body: body, - Header: http.Header{}, - Trailer: http.Header{}, - Proto: r.params["SERVER_PROTOCOL"], - } - - var ok bool - req.ProtoMajor, req.ProtoMinor, ok = http.ParseHTTPVersion(req.Proto) - if !ok { - return nil, os.NewError("fcgi: invalid HTTP version") - } - - req.Host = r.params["HTTP_HOST"] - req.Referer = r.params["HTTP_REFERER"] - req.UserAgent = r.params["HTTP_USER_AGENT"] - - if lenstr := r.params["CONTENT_LENGTH"]; lenstr != "" { - clen, err := strconv.Atoi64(r.params["CONTENT_LENGTH"]) - if err != nil { - return nil, os.NewError("fcgi: bad CONTENT_LENGTH parameter: " + lenstr) - } - req.ContentLength = clen - } - - if req.Host != "" { - req.RawURL = "http://" + req.Host + r.params["REQUEST_URI"] - url, err := http.ParseURL(req.RawURL) - if err != nil { - return nil, os.NewError("fcgi: failed to parse host and REQUEST_URI into a URL: " + req.RawURL) - } - req.URL = url - } - if req.URL == nil { - req.RawURL = r.params["REQUEST_URI"] - url, err := http.ParseURL(req.RawURL) - if err != nil { - return nil, os.NewError("fcgi: failed to parse REQUEST_URI into a URL: " + req.RawURL) - } - req.URL = url - } - - for key, val := range r.params { - if strings.HasPrefix(key, "HTTP_") && !skipHeader[key] { - req.Header.Add(strings.Replace(key[5:], "_", "-", -1), val) - } - } - return req, nil -} - // parseParams reads an encoded []byte into Params. func (r *request) parseParams() { text := r.rawParams @@ -169,15 +106,7 @@ func (r *response) WriteHeader(code int) { } fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) - // TODO(eds): this is duplicated in http and http/cgi - for k, vv := range r.header { - for _, v := range vv { - v = strings.Replace(v, "\n", "", -1) - v = strings.Replace(v, "\r", "", -1) - v = strings.TrimSpace(v) - fmt.Fprintf(r.w, "%s: %s\r\n", k, v) - } - } + r.header.Write(r.w) r.w.WriteString("\r\n") } @@ -281,12 +210,13 @@ func (c *child) serve() { func (c *child) serveRequest(req *request, body io.ReadCloser) { r := newResponse(c, req) - httpReq, err := req.httpRequest(body) + httpReq, err := cgi.RequestFromMap(req.params) if err != nil { // there was an error reading the request r.WriteHeader(http.StatusInternalServerError) c.conn.writeRecord(typeStderr, req.reqId, []byte(err.String())) } else { + httpReq.Body = body c.handler.ServeHTTP(r, httpReq) } if body != nil { diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go index 17d5297b8..28a0c51ef 100644 --- a/src/pkg/http/fs.go +++ b/src/pkg/http/fs.go @@ -175,7 +175,9 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { } w.Header().Set("Accept-Ranges", "bytes") - w.Header().Set("Content-Length", strconv.Itoa64(size)) + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.Itoa64(size)) + } w.WriteHeader(code) diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go index 09d0981f2..554053449 100644 --- a/src/pkg/http/fs_test.go +++ b/src/pkg/http/fs_test.go @@ -96,12 +96,12 @@ func TestServeFileContentType(t *testing.T) { })) defer ts.Close() get := func(want string) { - resp, _, err := Get(ts.URL) + resp, err := Get(ts.URL) if err != nil { t.Fatal(err) } if h := resp.Header.Get("Content-Type"); h != want { - t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + t.Errorf("Content-Type mismatch: got %d, want %d", h, want) } } get("text/plain; charset=utf-8") @@ -109,6 +109,21 @@ func TestServeFileContentType(t *testing.T) { get(ctype) } +func TestServeFileWithContentEncoding(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "foo") + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if g, e := resp.ContentLength, int64(-1); g != e { + t.Errorf("Content-Length mismatch: got %q, want %q", g, e) + } +} + func getBody(t *testing.T, req Request) (*Response, []byte) { r, err := DefaultClient.Do(&req) if err != nil { diff --git a/src/pkg/http/header.go b/src/pkg/http/header.go index 95b0f3db6..95140b01f 100644 --- a/src/pkg/http/header.go +++ b/src/pkg/http/header.go @@ -4,7 +4,14 @@ package http -import "net/textproto" +import ( + "fmt" + "io" + "net/textproto" + "os" + "sort" + "strings" +) // A Header represents the key-value pairs in an HTTP header. type Header map[string][]string @@ -35,6 +42,37 @@ func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) os.Error { + return h.WriteSubset(w, nil) +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) os.Error { + keys := make([]string, 0, len(h)) + for k := range h { + if exclude == nil || !exclude[k] { + keys = append(keys, k) + } + } + sort.SortStrings(keys) + for _, k := range keys { + for _, v := range h[k] { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.TrimSpace(v) + if v == "" { + continue + } + if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { + return err + } + } + } + return nil +} + // CanonicalHeaderKey returns the canonical format of the // header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; diff --git a/src/pkg/http/header_test.go b/src/pkg/http/header_test.go new file mode 100644 index 000000000..7e24cb069 --- /dev/null +++ b/src/pkg/http/header_test.go @@ -0,0 +1,71 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "testing" +) + +var headerWriteTests = []struct { + h Header + exclude map[string]bool + expected string +}{ + {Header{}, nil, ""}, + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "Content-Length": {"0"}, + }, + nil, + "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n", + }, + { + Header{ + "Content-Length": {"0", "1", "2"}, + }, + nil, + "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0", "1", "2"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, + "", + }, +} + +func TestHeaderWrite(t *testing.T) { + var buf bytes.Buffer + for i, test := range headerWriteTests { + test.h.WriteSubset(&buf, test.exclude) + if buf.String() != test.expected { + t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected) + } + buf.Reset() + } +} diff --git a/src/pkg/http/httptest/server.go b/src/pkg/http/httptest/server.go index 8e385d045..879f04f33 100644 --- a/src/pkg/http/httptest/server.go +++ b/src/pkg/http/httptest/server.go @@ -108,29 +108,24 @@ func (s *Server) CloseClientConnections() { // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBwTCCASugAwIBAgIBADALBgkqhkiG9w0BAQUwADAeFw0xMTAzMzEyMDI1MDda -Fw00OTEyMzEyMzU5NTlaMAAwggCdMAsGCSqGSIb3DQEBAQOCAIwAMIIAhwKCAIB6 -oy4iT42G6qk+GGn5VL5JlnJT6ZG5cqaMNFaNGlIxNb6CPUZLKq2sM3gRaimsktIw -nNAcNwQGHpe1tZo+J/Pl04JTt71Y/TTAxy7OX27aZf1Rpt0SjdZ7vTPnFDPNsHGe -KBKvPt55l2+YKjkZmV7eRevsVbpkNvNGB+T5d4Ge/wIBA6NPME0wDgYDVR0PAQH/ -BAQDAgCgMA0GA1UdDgQGBAQBAgMEMA8GA1UdIwQIMAaABAECAwQwGwYDVR0RBBQw -EoIJMTI3LjAuMC4xggVbOjoxXTALBgkqhkiG9w0BAQUDggCBAHC3gbdvc44vs+wD -g2kONiENnx8WKc0UTGg/TOXS3gaRb+CUIQtHWja65l8rAfclEovjHgZ7gx8brO0W -JuC6p3MUAKsgOssIrrRIx2rpnfcmFVMzguCmrMNVmKUAalw18Yp0F72xYAIitVQl -kJrLdIhBajcJRYu/YGltHQRaXuVt +MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX +DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 +qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL +8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw +DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG +CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7 ++l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM= -----END CERTIFICATE----- `) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIBkgIBAQKCAIB6oy4iT42G6qk+GGn5VL5JlnJT6ZG5cqaMNFaNGlIxNb6CPUZL -Kq2sM3gRaimsktIwnNAcNwQGHpe1tZo+J/Pl04JTt71Y/TTAxy7OX27aZf1Rpt0S -jdZ7vTPnFDPNsHGeKBKvPt55l2+YKjkZmV7eRevsVbpkNvNGB+T5d4Ge/wIBAwKC -AIBRwh7Bil5Z8cYpZZv7jdQxDvbim7Z7ocRdeDmzZuF2I9RW04QyHHPIIlALnBvI -YeF1veASz1gEFGUjzmbUGqKYSbCoTzXoev+F4bmbRxcX9sOmtslqvhMSHRSzA5NH -aDVI3Hn4wvBVD8gePu8ACWqvPGbCiql11OKCMfjlPn2uuwJAx/24/F5DjXZ6hQQ7 -HxScOxKrpx5WnA9r1wZTltOTZkhRRzuLc21WJeE3M15QUdWi3zZxCKRFoth65HEs -jy9YHQJAnPueRI44tz79b5QqVbeaOMUr7ZCb1Kp0uo6G+ANPLdlfliAupwij2eIz -mHRJOWk0jBtXfRft1McH2H51CpXAyw== +MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v +tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi +SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0 +3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ +/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN +poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh +AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93 -----END RSA PRIVATE KEY----- `) diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go index e4eea6815..62f9ff1b5 100644 --- a/src/pkg/http/persist.go +++ b/src/pkg/http/persist.go @@ -111,7 +111,7 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) { // Make sure body is fully consumed, even if user does not call body.Close if lastbody != nil { // body.Close is assumed to be idempotent and multiple calls to - // it should return the error that its first invokation + // it should return the error that its first invocation // returned. err = lastbody.Close() if err != nil { @@ -222,7 +222,6 @@ type ClientConn struct { pipe textproto.Pipeline writeReq func(*Request, io.Writer) os.Error - readRes func(buf *bufio.Reader, method string) (*Response, os.Error) } // NewClientConn returns a new ClientConn reading and writing c. If r is not @@ -236,7 +235,6 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { r: r, pipereq: make(map[*Request]uint), writeReq: (*Request).Write, - readRes: ReadResponse, } } @@ -339,8 +337,13 @@ func (cc *ClientConn) Pending() int { // returned together with an ErrPersistEOF, which means that the remote // requested that this be the last request serviced. Read can be called // concurrently with Write, but not with another Read. -func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { +func (cc *ClientConn) Read(req *Request) (*Response, os.Error) { + return cc.readUsing(req, ReadResponse) +} +// readUsing is the implementation of Read with a replaceable +// ReadResponse-like function, used by the Transport. +func (cc *ClientConn) readUsing(req *Request, readRes func(*bufio.Reader, *Request) (*Response, os.Error)) (resp *Response, err os.Error) { // Retrieve the pipeline ID of this request/response pair cc.lk.Lock() id, ok := cc.pipereq[req] @@ -383,7 +386,7 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { } } - resp, err = cc.readRes(r, req.Method) + resp, err = readRes(r, req) cc.lk.Lock() defer cc.lk.Unlock() if err != nil { diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go index 308bf44b4..9b320b3aa 100644 --- a/src/pkg/http/proxy_test.go +++ b/src/pkg/http/proxy_test.go @@ -40,10 +40,8 @@ func TestUseProxy(t *testing.T) { no_proxy := "foobar.com, .barbaz.net" os.Setenv("NO_PROXY", no_proxy) - tr := &Transport{} - for _, test := range UseProxyTests { - if tr.useProxy(test.host+":80") != test.match { + if useProxy(test.host+":80") != test.match { t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } } diff --git a/src/pkg/http/readrequest_test.go b/src/pkg/http/readrequest_test.go index 19e2ff774..d93e573f5 100644 --- a/src/pkg/http/readrequest_test.go +++ b/src/pkg/http/readrequest_test.go @@ -64,7 +64,7 @@ var reqTests = []reqTest{ Host: "www.techcrunch.com", Referer: "", UserAgent: "Fake", - Form: map[string][]string{}, + Form: Values{}, }, "abcdef\n", @@ -99,7 +99,7 @@ var reqTests = []reqTest{ Host: "test", Referer: "", UserAgent: "", - Form: map[string][]string{}, + Form: Values{}, }, "", diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index 4852ca3e1..bdc3a7e4f 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -10,8 +10,10 @@ package http import ( "bufio" + "bytes" "crypto/tls" "container/vector" + "encoding/base64" "fmt" "io" "io/ioutil" @@ -88,10 +90,10 @@ type Request struct { // // then // - // Header = map[string]string{ - // "Accept-Encoding": "gzip, deflate", - // "Accept-Language": "en-us", - // "Connection": "keep-alive", + // Header = map[string][]string{ + // "Accept-Encoding": {"gzip, deflate"}, + // "Accept-Language": {"en-us"}, + // "Connection": {"keep-alive"}, // } // // HTTP defines that header names are case-insensitive. @@ -139,7 +141,7 @@ type Request struct { UserAgent string // The parsed form. Only available after ParseForm is called. - Form map[string][]string + Form Values // The parsed multipart form, including file uploads. // Only available after ParseMultipartForm is called. @@ -230,15 +232,15 @@ const defaultUserAgent = "Go http package" // Method (defaults to "GET") // UserAgent (defaults to defaultUserAgent) // Referer -// Header +// Header (only keys not already in this list) // Cookie // ContentLength // TransferEncoding // Body // -// If Body is present but Content-Length is <= 0, Write adds -// "Transfer-Encoding: chunked" to the header. Body is closed after -// it is sent. +// If Body is present, Content-Length is <= 0 and TransferEncoding +// hasn't been set to "identity", Write adds "Transfer-Encoding: +// chunked" to the header. Body is closed after it is sent. func (req *Request) Write(w io.Writer) os.Error { return req.write(w, false) } @@ -255,6 +257,9 @@ func (req *Request) WriteProxy(w io.Writer) os.Error { func (req *Request) write(w io.Writer, usingProxy bool) os.Error { host := req.Host if host == "" { + if req.URL == nil { + return os.NewError("http: Request.Write on Request with no Host or URL set") + } host = req.URL.Host } @@ -275,9 +280,7 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), uri) // Header lines - if !usingProxy { - fmt.Fprintf(w, "Host: %s\r\n", host) - } + fmt.Fprintf(w, "Host: %s\r\n", host) fmt.Fprintf(w, "User-Agent: %s\r\n", valueOrDefault(req.UserAgent, defaultUserAgent)) if req.Referer != "" { fmt.Fprintf(w, "Referer: %s\r\n", req.Referer) @@ -300,7 +303,7 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { // from Request, and introduce Request methods along the lines of // Response.{GetHeader,AddHeader} and string constants for "Host", // "User-Agent" and "Referer". - err = writeSortedHeader(w, req.Header, reqExcludeHeader) + err = req.Header.WriteSubset(w, reqExcludeHeader) if err != nil { return err } @@ -476,9 +479,35 @@ func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { Body: rc, Host: u.Host, } + if body != nil { + switch v := body.(type) { + case *strings.Reader: + req.ContentLength = int64(v.Len()) + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + default: + req.ContentLength = -1 // chunked + } + if req.ContentLength == 0 { + // To prevent chunking and disambiguate this + // from the default ContentLength zero value. + req.TransferEncoding = []string{"identity"} + } + } + return req, nil } +// SetBasicAuth sets the request's Authorization header to use HTTP +// Basic Authentication with the provided username and password. +// +// With HTTP Basic Authentication the provided username and password +// are not encrypted. +func (r *Request) SetBasicAuth(username, password string) { + s := username + ":" + password + r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { @@ -573,18 +602,56 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { return req, nil } +// Values maps a string key to a list of values. +// It is typically used for query parameters and form values. +// Unlike in the Header map, the keys in a Values map +// are case-sensitive. +type Values map[string][]string + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns +// the empty string. To access multiple values, use the map +// directly. +func (v Values) Get(key string) string { + if v == nil { + return "" + } + vs, ok := v[key] + if !ok || len(vs) == 0 { + return "" + } + return vs[0] +} + +// Set sets the key to value. It replaces any existing +// values. +func (v Values) Set(key, value string) { + v[key] = []string{value} +} + +// Add adds the key to value. It appends to any existing +// values associated with key. +func (v Values) Add(key, value string) { + v[key] = append(v[key], value) +} + +// Del deletes the values associated with key. +func (v Values) Del(key string) { + v[key] = nil, false +} + // ParseQuery parses the URL-encoded query string and returns // a map listing the values specified for each key. // ParseQuery always returns a non-nil map containing all the // valid query parameters found; err describes the first decoding error // encountered, if any. -func ParseQuery(query string) (m map[string][]string, err os.Error) { - m = make(map[string][]string) +func ParseQuery(query string) (m Values, err os.Error) { + m = make(Values) err = parseQuery(m, query) return } -func parseQuery(m map[string][]string, query string) (err os.Error) { +func parseQuery(m Values, query string) (err os.Error) { for _, kv := range strings.Split(query, "&", -1) { if len(kv) == 0 { continue @@ -617,7 +684,7 @@ func (r *Request) ParseForm() (err os.Error) { return } - r.Form = make(map[string][]string) + r.Form = make(Values) if r.URL != nil { err = parseQuery(r.Form, r.URL.RawQuery) } diff --git a/src/pkg/http/request_test.go b/src/pkg/http/request_test.go index f79d3a242..e03ed3b05 100644 --- a/src/pkg/http/request_test.go +++ b/src/pkg/http/request_test.go @@ -162,16 +162,25 @@ func TestRedirect(t *testing.T) { defer ts.Close() var end = regexp.MustCompile("/foo/$") - r, url, err := Get(ts.URL) + r, err := Get(ts.URL) if err != nil { t.Fatal(err) } r.Body.Close() + url := r.Request.URL.String() if r.StatusCode != 200 || !end.MatchString(url) { t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url) } } +func TestSetBasicAuth(t *testing.T) { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.SetBasicAuth("Aladdin", "open sesame") + if g, e := r.Header.Get("Authorization"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; g != e { + t.Errorf("got header %q, want %q", g, e) + } +} + func TestMultipartRequest(t *testing.T) { // Test that we can read the values and files of a // multipart request with FormValue and FormFile, @@ -213,13 +222,13 @@ func TestEmptyMultipartRequest(t *testing.T) { func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { - t.Errorf("FormFile file = %q, want nil", f, nil) + t.Errorf("FormFile file = %q, want nil", f) } if fh != nil { - t.Errorf("FormFile file header = %q, want nil", fh, nil) + t.Errorf("FormFile file header = %q, want nil", fh) } if err != ErrMissingFile { - t.Errorf("FormFile err = %q, want nil", err, ErrMissingFile) + t.Errorf("FormFile err = %q, want ErrMissingFile", err) } } @@ -227,7 +236,7 @@ func newTestMultipartRequest(t *testing.T) *Request { b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) req, err := NewRequest("POST", "/", b) if err != nil { - t.Fatalf("NewRequest:", err) + t.Fatal("NewRequest:", err) } ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) req.Header.Set("Content-type", ctype) @@ -267,7 +276,7 @@ func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { f, fh, err := req.FormFile(key) if err != nil { - t.Fatalf("FormFile(%q):", key, err) + t.Fatalf("FormFile(%q): %q", key, err) } if fh.Filename != expectFilename { t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index bb000c701..98fbcf459 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -69,6 +69,7 @@ var reqWriteTests = []reqWriteTest{ "Proxy-Connection: keep-alive\r\n\r\n", "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + "User-Agent: Fake\r\n" + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + @@ -101,6 +102,7 @@ var reqWriteTests = []reqWriteTest{ "6\r\nabcdef\r\n0\r\n\r\n", "GET http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", @@ -131,6 +133,7 @@ var reqWriteTests = []reqWriteTest{ "6\r\nabcdef\r\n0\r\n\r\n", "POST http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + @@ -164,6 +167,7 @@ var reqWriteTests = []reqWriteTest{ "abcdef", "POST http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + @@ -171,6 +175,35 @@ var reqWriteTests = []reqWriteTest{ "abcdef", }, + // HTTP/1.1 POST with Content-Length in headers + { + Request{ + Method: "POST", + RawURL: "http://example.com/", + Host: "example.com", + Header: Header{ + "Content-Length": []string{"10"}, // ignored + }, + ContentLength: 6, + }, + + []byte("abcdef"), + + "POST http://example.com/ HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + "POST http://example.com/ HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + // default to HTTP/1.1 { Request{ @@ -188,6 +221,7 @@ var reqWriteTests = []reqWriteTest{ // Looks weird but RawURL overrides what WriteProxy would choose. "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "\r\n", }, @@ -240,13 +274,45 @@ func (rc *closeChecker) Close() os.Error { // TestRequestWriteClosesBody tests that Request.Write does close its request.Body. // It also indirectly tests NewRequest and that it doesn't wrap an existing Closer -// inside a NopCloser. +// inside a NopCloser, and that it serializes it correctly. func TestRequestWriteClosesBody(t *testing.T) { rc := &closeChecker{Reader: strings.NewReader("my body")} - req, _ := NewRequest("GET", "http://foo.com/", rc) + req, _ := NewRequest("POST", "http://foo.com/", rc) + if g, e := req.ContentLength, int64(-1); g != e { + t.Errorf("got req.ContentLength %d, want %d", g, e) + } buf := new(bytes.Buffer) req.Write(buf) if !rc.closed { t.Error("body not closed after write") } + if g, e := buf.String(), "POST / HTTP/1.1\r\nHost: foo.com\r\nUser-Agent: Go http package\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nmy body\r\n0\r\n\r\n"; g != e { + t.Errorf("write:\n got: %s\nwant: %s", g, e) + } +} + +func TestZeroLengthNewRequest(t *testing.T) { + var buf bytes.Buffer + + // Writing with default identity encoding + req, _ := NewRequest("PUT", "http://foo.com/", strings.NewReader("")) + if len(req.TransferEncoding) == 0 || req.TransferEncoding[0] != "identity" { + t.Fatalf("got req.TransferEncoding of %v, want %v", req.TransferEncoding, []string{"identity"}) + } + if g, e := req.ContentLength, int64(0); g != e { + t.Errorf("got req.ContentLength %d, want %d", g, e) + } + req.Write(&buf) + if g, e := buf.String(), "PUT / HTTP/1.1\r\nHost: foo.com\r\nUser-Agent: Go http package\r\nContent-Length: 0\r\n\r\n"; g != e { + t.Errorf("identity write:\n got: %s\nwant: %s", g, e) + } + + // Overriding identity encoding and forcing chunked. + req, _ = NewRequest("PUT", "http://foo.com/", strings.NewReader("")) + req.TransferEncoding = nil + buf.Reset() + req.Write(&buf) + if g, e := buf.String(), "PUT / HTTP/1.1\r\nHost: foo.com\r\nUser-Agent: Go http package\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; g != e { + t.Errorf("chunked write:\n got: %s\nwant: %s", g, e) + } } diff --git a/src/pkg/http/response.go b/src/pkg/http/response.go index 1f725ecdd..42e60c1f6 100644 --- a/src/pkg/http/response.go +++ b/src/pkg/http/response.go @@ -8,11 +8,9 @@ package http import ( "bufio" - "fmt" "io" "net/textproto" "os" - "sort" "strconv" "strings" ) @@ -32,10 +30,6 @@ type Response struct { ProtoMajor int // e.g. 1 ProtoMinor int // e.g. 0 - // RequestMethod records the method used in the HTTP request. - // Header fields such as Content-Length have method-specific meaning. - RequestMethod string // e.g. "HEAD", "CONNECT", "GET", etc. - // Header maps header keys to values. If the response had multiple // headers with the same key, they will be concatenated, with comma // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers @@ -70,19 +64,26 @@ type Response struct { // Trailer maps trailer keys to values, in the same // format as the header. Trailer Header + + // The Request that was sent to obtain this Response. + // Request's Body is nil (having already been consumed). + // This is only populated for Client requests. + Request *Request } -// ReadResponse reads and returns an HTTP response from r. The RequestMethod -// parameter specifies the method used in the corresponding request (e.g., -// "GET", "HEAD"). Clients must call resp.Body.Close when finished reading -// resp.Body. After that call, clients can inspect resp.Trailer to find -// key/value pairs included in the response trailer. -func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { +// ReadResponse reads and returns an HTTP response from r. The +// req parameter specifies the Request that corresponds to +// this Response. Clients must call resp.Body.Close when finished +// reading resp.Body. After that call, clients can inspect +// resp.Trailer to find key/value pairs included in the response +// trailer. +func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err os.Error) { tp := textproto.NewReader(r) resp = new(Response) - resp.RequestMethod = strings.ToUpper(requestMethod) + resp.Request = req + resp.Request.Method = strings.ToUpper(resp.Request.Method) // Parse the first line of the response. line, err := tp.ReadLine() @@ -166,7 +167,9 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { func (resp *Response) Write(w io.Writer) os.Error { // RequestMethod should be upper-case - resp.RequestMethod = strings.ToUpper(resp.RequestMethod) + if resp.Request != nil { + resp.Request.Method = strings.ToUpper(resp.Request.Method) + } // Status line text := resp.Status @@ -192,7 +195,7 @@ func (resp *Response) Write(w io.Writer) os.Error { } // Rest of header - err = writeSortedHeader(w, resp.Header, respExcludeHeader) + err = resp.Header.WriteSubset(w, respExcludeHeader) if err != nil { return err } @@ -213,27 +216,3 @@ func (resp *Response) Write(w io.Writer) os.Error { // Success return nil } - -func writeSortedHeader(w io.Writer, h Header, exclude map[string]bool) os.Error { - keys := make([]string, 0, len(h)) - for k := range h { - if exclude == nil || !exclude[k] { - keys = append(keys, k) - } - } - sort.SortStrings(keys) - for _, k := range keys { - for _, v := range h[k] { - v = strings.Replace(v, "\n", " ", -1) - v = strings.Replace(v, "\r", " ", -1) - v = strings.TrimSpace(v) - if v == "" { - continue - } - if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err - } - } - } - return nil -} diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go index 9e77c20c4..1d4a23423 100644 --- a/src/pkg/http/response_test.go +++ b/src/pkg/http/response_test.go @@ -23,6 +23,10 @@ type respTest struct { Body string } +func dummyReq(method string) *Request { + return &Request{Method: method} +} + var respTests = []respTest{ // Unchunked response without Content-Length. { @@ -32,12 +36,12 @@ var respTests = []respTest{ "Body here\n", Response{ - Status: "200 OK", - StatusCode: 200, - Proto: "HTTP/1.0", - ProtoMajor: 1, - ProtoMinor: 0, - RequestMethod: "GET", + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), Header: Header{ "Connection": {"close"}, // TODO(rsc): Delete? }, @@ -61,7 +65,7 @@ var respTests = []respTest{ Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, - RequestMethod: "GET", + Request: dummyReq("GET"), Close: true, ContentLength: -1, }, @@ -81,7 +85,7 @@ var respTests = []respTest{ Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, - RequestMethod: "GET", + Request: dummyReq("GET"), Close: false, ContentLength: 0, }, @@ -98,12 +102,12 @@ var respTests = []respTest{ "Body here\n", Response{ - Status: "200 OK", - StatusCode: 200, - Proto: "HTTP/1.0", - ProtoMajor: 1, - ProtoMinor: 0, - RequestMethod: "GET", + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), Header: Header{ "Connection": {"close"}, // TODO(rsc): Delete? "Content-Length": {"10"}, // TODO(rsc): Delete? @@ -133,7 +137,7 @@ var respTests = []respTest{ Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Close: true, ContentLength: -1, @@ -160,7 +164,7 @@ var respTests = []respTest{ Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Close: true, ContentLength: -1, // TODO(rsc): Fix? @@ -183,7 +187,7 @@ var respTests = []respTest{ Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "HEAD", + Request: dummyReq("HEAD"), Header: Header{}, Close: true, ContentLength: 0, @@ -199,12 +203,12 @@ var respTests = []respTest{ "\r\n", Response{ - Status: "200 OK", - StatusCode: 200, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - RequestMethod: "GET", + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), Header: Header{ "Content-Length": {"0"}, }, @@ -225,7 +229,7 @@ var respTests = []respTest{ Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Close: true, ContentLength: -1, @@ -244,7 +248,7 @@ var respTests = []respTest{ Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Close: true, ContentLength: -1, @@ -259,7 +263,7 @@ func TestReadResponse(t *testing.T) { tt := &respTests[i] var braw bytes.Buffer braw.WriteString(tt.Raw) - resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.RequestMethod) + resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request) if err != nil { t.Errorf("#%d: %s", i, err) continue @@ -340,7 +344,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { buf.WriteString("Next Request Here") bufr := bufio.NewReader(&buf) - resp, err := ReadResponse(bufr, "GET") + resp, err := ReadResponse(bufr, dummyReq("GET")) checkErr(err, "ReadResponse") expectedLength := int64(-1) if !test.chunked { @@ -372,7 +376,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { rest, err := ioutil.ReadAll(bufr) checkErr(err, "ReadAll on remainder") if e, g := "Next Request Here", string(rest); e != g { - fatalf("for chunked=%v remainder = %q, expected %q", g, e) + fatalf("remainder = %q, expected %q", g, e) } } } @@ -381,7 +385,7 @@ func diff(t *testing.T, prefix string, have, want interface{}) { hv := reflect.ValueOf(have).Elem() wv := reflect.ValueOf(want).Elem() if hv.Type() != wv.Type() { - t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type()) + t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type()) } for i := 0; i < hv.NumField(); i++ { hf := hv.Field(i).Interface() diff --git a/src/pkg/http/responsewrite_test.go b/src/pkg/http/responsewrite_test.go index de0635da5..f8e63acf4 100644 --- a/src/pkg/http/responsewrite_test.go +++ b/src/pkg/http/responsewrite_test.go @@ -22,7 +22,7 @@ var respWriteTests = []respWriteTest{ StatusCode: 503, ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), ContentLength: 6, @@ -38,7 +38,7 @@ var respWriteTests = []respWriteTest{ StatusCode: 200, ProtoMajor: 1, ProtoMinor: 0, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), ContentLength: -1, @@ -53,7 +53,7 @@ var respWriteTests = []respWriteTest{ StatusCode: 200, ProtoMajor: 1, ProtoMinor: 1, - RequestMethod: "GET", + Request: dummyReq("GET"), Header: Header{}, Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), ContentLength: 6, @@ -71,10 +71,10 @@ var respWriteTests = []respWriteTest{ // Also tests removal of leading and trailing whitespace. { Response{ - StatusCode: 204, - ProtoMajor: 1, - ProtoMinor: 1, - RequestMethod: "GET", + StatusCode: 204, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), Header: Header{ "Foo": []string{" Bar\nBaz "}, }, diff --git a/src/pkg/http/reverseproxy.go b/src/pkg/http/reverseproxy.go index e4ce1e34c..9a9e21599 100644 --- a/src/pkg/http/reverseproxy.go +++ b/src/pkg/http/reverseproxy.go @@ -92,6 +92,10 @@ func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) { } } + for _, cookie := range res.SetCookie { + SetCookie(rw, cookie) + } + rw.WriteHeader(res.StatusCode) if res.Body != nil { diff --git a/src/pkg/http/reverseproxy_test.go b/src/pkg/http/reverseproxy_test.go index 8cf7705d7..d7bcde90d 100644 --- a/src/pkg/http/reverseproxy_test.go +++ b/src/pkg/http/reverseproxy_test.go @@ -20,7 +20,11 @@ func TestReverseProxy(t *testing.T) { if r.Header.Get("X-Forwarded-For") == "" { t.Errorf("didn't get X-Forwarded-For header") } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } w.Header().Set("X-Foo", "bar") + SetCookie(w, &Cookie{Name: "flavor", Value: "chocolateChip"}) w.WriteHeader(backendStatus) w.Write([]byte(backendResponse)) })) @@ -33,7 +37,9 @@ func TestReverseProxy(t *testing.T) { frontend := httptest.NewServer(proxyHandler) defer frontend.Close() - res, _, err := Get(frontend.URL) + getReq, _ := NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + res, err := DefaultClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -43,6 +49,12 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Header.Get("X-Foo"), "bar"; g != e { t.Errorf("got X-Foo %q; expected %q", g, e) } + if g, e := len(res.SetCookie), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if cookie := res.SetCookie[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } bodyBytes, _ := ioutil.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index c3c7b8d33..dc4594a79 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -12,11 +12,14 @@ import ( "fmt" . "http" "http/httptest" + "io" "io/ioutil" + "log" "os" "net" "reflect" "strings" + "syscall" "testing" "time" ) @@ -252,7 +255,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test c := &Client{Transport: tr} - r, _, err := c.Get(url) + r, err := c.Get(url) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -282,7 +285,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, _, err = Get(url) + r, err = Get(url) if err != nil { t.Fatalf("http Get #2: %v", err) } @@ -323,7 +326,7 @@ func TestIdentityResponse(t *testing.T) { // responses. for _, te := range []string{"", "identity"} { url := ts.URL + "/?te=" + te - res, _, err := Get(url) + res, err := Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } @@ -342,7 +345,7 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - _, _, err := Get(url) + _, err := Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } @@ -389,7 +392,7 @@ func TestServeHTTP10Close(t *testing.T) { } r := bufio.NewReader(conn) - _, err = ReadResponse(r, "GET") + _, err = ReadResponse(r, &Request{Method: "GET"}) if err != nil { t.Fatal("ReadResponse error:", err) } @@ -417,7 +420,7 @@ func TestSetsRemoteAddr(t *testing.T) { })) defer ts.Close() - res, _, err := Get(ts.URL) + res, err := Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) } @@ -432,13 +435,16 @@ func TestSetsRemoteAddr(t *testing.T) { } func TestChunkedResponseHeaders(t *testing.T) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted fmt.Fprintf(w, "I am a chunked response.") })) defer ts.Close() - res, _, err := Get(ts.URL) + res, err := Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) } @@ -465,7 +471,7 @@ func Test304Responses(t *testing.T) { } })) defer ts.Close() - res, _, err := Get(ts.URL) + res, err := Get(ts.URL) if err != nil { t.Error(err) } @@ -490,6 +496,12 @@ func TestHeadResponses(t *testing.T) { if err != ErrBodyNotAllowed { t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) } + + // Also exercise the ReaderFrom path + _, err = io.Copy(w, strings.NewReader("Ignored body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Copy, expected ErrBodyNotAllowed, got %v", err) + } })) defer ts.Close() res, err := Head(ts.URL) @@ -516,7 +528,7 @@ func TestTLSServer(t *testing.T) { if !strings.HasPrefix(ts.URL, "https://") { t.Fatalf("expected test TLS server to start with https://, got %q", ts.URL) } - res, _, err := Get(ts.URL) + res, err := Get(ts.URL) if err != nil { t.Error(err) } @@ -551,7 +563,7 @@ var serverExpectTests = []serverExpectTest{ {100, "", true, "200 OK"}, // 100-continue but requesting client to deny us, - // so it never eads the body. + // so it never reads the body. {100, "100-continue", false, "401 Unauthorized"}, // Likewise without 100-continue: {100, "", false, "401 Unauthorized"}, @@ -618,49 +630,29 @@ func TestServerExpect(t *testing.T) { } func TestServerConsumesRequestBody(t *testing.T) { - log := make(chan string, 100) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - log <- "got_request" - w.WriteHeader(StatusOK) - log <- "wrote_header" - })) - defer ts.Close() - - conn, err := net.Dial("tcp", ts.Listener.Addr().String()) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer conn.Close() - - bufr := bufio.NewReader(conn) - gotres := make(chan bool) - go func() { - line, err := bufr.ReadString('\n') - if err != nil { - t.Fatal(err) + conn := new(testConn) + body := strings.Repeat("x", 1<<20) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n", len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) } - log <- line - gotres <- true - }() - - size := 1 << 20 - log <- "writing_request" - fmt.Fprintf(conn, "POST / HTTP/1.0\r\nContent-Length: %d\r\n\r\n", size) - time.Sleep(25e6) // give server chance to misbehave & speak out of turn - log <- "slept_after_req_headers" - conn.Write([]byte(strings.Repeat("a", size))) - - <-gotres - expected := []string{ - "writing_request", "got_request", - "slept_after_req_headers", "wrote_header", - "HTTP/1.0 200 OK\r\n"} - for step, e := range expected { - if g := <-log; e != g { - t.Errorf("on step %d expected %q, got %q", step, e, g) + rw.WriteHeader(200) + if g, e := conn.readBuf.Len(), 0; g != e { + t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } - } + done <- true + })) + <-done } func TestTimeoutHandler(t *testing.T) { @@ -677,7 +669,7 @@ func TestTimeoutHandler(t *testing.T) { // Succeed without timing out: sendHi <- true - res, _, err := Get(ts.URL) + res, err := Get(ts.URL) if err != nil { t.Error(err) } @@ -694,7 +686,7 @@ func TestTimeoutHandler(t *testing.T) { // Times out: timeout <- 1 - res, _, err = Get(ts.URL) + res, err = Get(ts.URL) if err != nil { t.Error(err) } @@ -713,3 +705,140 @@ func TestTimeoutHandler(t *testing.T) { t.Errorf("expected Write error of %v; got %v", e, g) } } + +// Verifies we don't path.Clean() on the wrong parts in redirects. +func TestRedirectMunging(t *testing.T) { + req, _ := NewRequest("GET", "http://example.com/", nil) + + resp := httptest.NewRecorder() + Redirect(resp, req, "/foo?next=http://bar.com/", 302) + if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e { + t.Errorf("Location header was %q; want %q", g, e) + } + + resp = httptest.NewRecorder() + Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302) + if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e { + t.Errorf("Location header was %q; want %q", g, e) + } +} + +// TestZeroLengthPostAndResponse exercises an optimization done by the Transport: +// when there is no body (either because the method doesn't permit a body, or an +// explicit Content-Length of zero is present), then the transport can re-use the +// connection immediately. But when it re-uses the connection, it typically closes +// the previous request's body, which is not optimal for zero-lengthed bodies, +// as the client would then see http.ErrBodyReadAfterClose and not 0, os.EOF. +func TestZeroLengthPostAndResponse(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + all, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("handler ReadAll: %v", err) + } + if len(all) != 0 { + t.Errorf("handler got %d bytes; expected 0", len(all)) + } + rw.Header().Set("Content-Length", "0") + })) + defer ts.Close() + + req, err := NewRequest("POST", ts.URL, strings.NewReader("")) + if err != nil { + t.Fatal(err) + } + req.ContentLength = 0 + + var resp [5]*Response + for i := range resp { + resp[i], err = DefaultClient.Do(req) + if err != nil { + t.Fatalf("client post #%d: %v", i, err) + } + } + + for i := range resp { + all, err := ioutil.ReadAll(resp[i].Body) + if err != nil { + t.Fatalf("req #%d: client ReadAll: %v", i, err) + } + if len(all) != 0 { + t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all)) + } + } +} + +func TestHandlerPanic(t *testing.T) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ts := httptest.NewServer(HandlerFunc(func(ResponseWriter, *Request) { + panic("intentional death for testing") + })) + defer ts.Close() + _, err := Get(ts.URL) + if err == nil { + t.Logf("expected an error") + } +} + +type errorListener struct { + errs []os.Error +} + +func (l *errorListener) Accept() (c net.Conn, err os.Error) { + if len(l.errs) == 0 { + return nil, os.EOF + } + err = l.errs[0] + l.errs = l.errs[1:] + return +} + +func (l *errorListener) Close() os.Error { + return nil +} + +func (l *errorListener) Addr() net.Addr { + return dummyAddr("test-address") +} + +func TestAcceptMaxFds(t *testing.T) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ln := &errorListener{[]os.Error{ + &net.OpError{ + Op: "accept", + Error: os.Errno(syscall.EMFILE), + }}} + err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {}))) + if err != os.EOF { + t.Errorf("got error %v, want EOF", err) + } +} + +func BenchmarkClientServer(b *testing.B) { + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + defer ts.Close() + b.StartTimer() + + for i := 0; i < b.N; i++ { + res, err := Get(ts.URL) + if err != nil { + panic("Get: " + err.String()) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + panic("ReadAll: " + err.String()) + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + + b.StopTimer() +} diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index 96d2cb638..d4638f127 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -6,12 +6,12 @@ // TODO(rsc): // logging -// post support package http import ( "bufio" + "bytes" "crypto/rand" "crypto/tls" "fmt" @@ -20,6 +20,7 @@ import ( "net" "os" "path" + "runtime" "strconv" "strings" "sync" @@ -119,6 +120,27 @@ type response struct { closeAfterReply bool } +type writerOnly struct { + io.Writer +} + +func (r *response) ReadFrom(src io.Reader) (n int64, err os.Error) { + // Flush before checking r.chunking, as Flush will call + // WriteHeader if it hasn't been called yet, and WriteHeader + // is what sets r.chunking. + r.Flush() + if !r.chunking && r.bodyAllowed() { + if rf, ok := r.conn.rwc.(io.ReaderFrom); ok { + n, err = rf.ReadFrom(src) + r.written += n + return + } + } + // Fall back to default io.Copy implementation. + // Use wrapper to hide r.ReadFrom from io.Copy. + return io.Copy(writerOnly{r}, src) +} + // Create new connection from rwc. func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { c = new(conn) @@ -309,10 +331,19 @@ func (w *response) WriteHeader(code int) { text = "status code " + codestring } io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - writeSortedHeader(w.conn.buf, w.header, nil) + w.header.Write(w.conn.buf) io.WriteString(w.conn.buf, "\r\n") } +// bodyAllowed returns true if a Write is allowed for this response type. +// It's illegal to call this before the header has been flushed. +func (w *response) bodyAllowed() bool { + if !w.wroteHeader { + panic("") + } + return w.status != StatusNotModified && w.req.Method != "HEAD" +} + func (w *response) Write(data []byte) (n int, err os.Error) { if w.conn.hijacked { log.Print("http: response.Write on hijacked connection") @@ -324,9 +355,7 @@ func (w *response) Write(data []byte) (n int, err os.Error) { if len(data) == 0 { return 0, nil } - - if w.status == StatusNotModified || w.req.Method == "HEAD" { - // Must not have body. + if !w.bodyAllowed() { return 0, ErrBodyNotAllowed } @@ -454,6 +483,33 @@ func (c *conn) close() { // Serve a new connection. func (c *conn) serve() { + defer func() { + err := recover() + if err == nil { + return + } + c.rwc.Close() + + // TODO(rsc,bradfitz): this is boilerplate. move it to runtime.Stack() + var buf bytes.Buffer + fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err) + for i := 1; i < 20; i++ { + pc, file, line, ok := runtime.Caller(i) + if !ok { + break + } + var name string + f := runtime.FuncForPC(pc) + if f != nil { + name = f.Name() + } else { + name = fmt.Sprintf("%#x", pc) + } + fmt.Fprintf(&buf, " %s %s:%d\n", name, file, line) + } + log.Print(buf.String()) + }() + for { w, err := c.readRequest() if err != nil { @@ -581,12 +637,18 @@ func Redirect(w ResponseWriter, r *Request, url string, code int) { url = olddir + url } + var query string + if i := strings.Index(url, "?"); i != -1 { + url, query = url[:i], url[i:] + } + // clean up but preserve trailing slash trailing := url[len(url)-1] == '/' url = path.Clean(url) if trailing && url[len(url)-1] != '/' { url += "/" } + url += query } } @@ -805,6 +867,10 @@ func (srv *Server) Serve(l net.Listener) os.Error { for { rw, e := l.Accept() if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + log.Printf("http: Accept error: %v", e) + continue + } return e } if srv.ReadTimeout != 0 { diff --git a/src/pkg/http/spdy/Makefile b/src/pkg/http/spdy/Makefile new file mode 100644 index 000000000..3bec220c4 --- /dev/null +++ b/src/pkg/http/spdy/Makefile @@ -0,0 +1,13 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=http/spdy +GOFILES=\ + read.go\ + types.go\ + write.go\ + +include ../../../Make.pkg diff --git a/src/pkg/http/spdy/read.go b/src/pkg/http/spdy/read.go new file mode 100644 index 000000000..159dbc578 --- /dev/null +++ b/src/pkg/http/spdy/read.go @@ -0,0 +1,287 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "compress/zlib" + "encoding/binary" + "http" + "io" + "os" + "strings" +) + +func (frame *SynStreamFrame) read(h ControlFrameHeader, f *Framer) os.Error { + return f.readSynStreamFrame(h, frame) +} + +func (frame *SynReplyFrame) read(h ControlFrameHeader, f *Framer) os.Error { + return f.readSynReplyFrame(h, frame) +} + +func (frame *RstStreamFrame) read(h ControlFrameHeader, f *Framer) os.Error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + if err := binary.Read(f.r, binary.BigEndian, &frame.Status); err != nil { + return err + } + return nil +} + +func (frame *SettingsFrame) read(h ControlFrameHeader, f *Framer) os.Error { + frame.CFHeader = h + var numSettings uint32 + if err := binary.Read(f.r, binary.BigEndian, &numSettings); err != nil { + return err + } + frame.FlagIdValues = make([]SettingsFlagIdValue, numSettings) + for i := uint32(0); i < numSettings; i++ { + if err := binary.Read(f.r, binary.BigEndian, &frame.FlagIdValues[i].Id); err != nil { + return err + } + frame.FlagIdValues[i].Flag = SettingsFlag((frame.FlagIdValues[i].Id & 0xff000000) >> 24) + frame.FlagIdValues[i].Id &= 0xffffff + if err := binary.Read(f.r, binary.BigEndian, &frame.FlagIdValues[i].Value); err != nil { + return err + } + } + return nil +} + +func (frame *NoopFrame) read(h ControlFrameHeader, f *Framer) os.Error { + frame.CFHeader = h + return nil +} + +func (frame *PingFrame) read(h ControlFrameHeader, f *Framer) os.Error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.Id); err != nil { + return err + } + return nil +} + +func (frame *GoAwayFrame) read(h ControlFrameHeader, f *Framer) os.Error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.LastGoodStreamId); err != nil { + return err + } + return nil +} + +func (frame *HeadersFrame) read(h ControlFrameHeader, f *Framer) os.Error { + return f.readHeadersFrame(h, frame) +} + +func newControlFrame(frameType ControlFrameType) (controlFrame, os.Error) { + ctor, ok := cframeCtor[frameType] + if !ok { + return nil, InvalidControlFrame + } + return ctor(), nil +} + +var cframeCtor = map[ControlFrameType]func() controlFrame{ + TypeSynStream: func() controlFrame { return new(SynStreamFrame) }, + TypeSynReply: func() controlFrame { return new(SynReplyFrame) }, + TypeRstStream: func() controlFrame { return new(RstStreamFrame) }, + TypeSettings: func() controlFrame { return new(SettingsFrame) }, + TypeNoop: func() controlFrame { return new(NoopFrame) }, + TypePing: func() controlFrame { return new(PingFrame) }, + TypeGoAway: func() controlFrame { return new(GoAwayFrame) }, + TypeHeaders: func() controlFrame { return new(HeadersFrame) }, + // TODO(willchan): Add TypeWindowUpdate +} + +type corkedReader struct { + r io.Reader + ch chan int + n int +} + +func (cr *corkedReader) Read(p []byte) (int, os.Error) { + if cr.n == 0 { + cr.n = <-cr.ch + } + if len(p) > cr.n { + p = p[:cr.n] + } + n, err := cr.r.Read(p) + cr.n -= n + return n, err +} + +func (f *Framer) uncorkHeaderDecompressor(payloadSize int) os.Error { + if f.headerDecompressor != nil { + f.headerReader.ch <- payloadSize + return nil + } + f.headerReader = corkedReader{r: f.r, ch: make(chan int, 1), n: payloadSize} + decompressor, err := zlib.NewReaderDict(&f.headerReader, []byte(HeaderDictionary)) + if err != nil { + return err + } + f.headerDecompressor = decompressor + return nil +} + +// ReadFrame reads SPDY encoded data and returns a decompressed Frame. +func (f *Framer) ReadFrame() (Frame, os.Error) { + var firstWord uint32 + if err := binary.Read(f.r, binary.BigEndian, &firstWord); err != nil { + return nil, err + } + if (firstWord & 0x80000000) != 0 { + frameType := ControlFrameType(firstWord & 0xffff) + version := uint16(0x7fff & (firstWord >> 16)) + return f.parseControlFrame(version, frameType) + } + return f.parseDataFrame(firstWord & 0x7fffffff) +} + +func (f *Framer) parseControlFrame(version uint16, frameType ControlFrameType) (Frame, os.Error) { + var length uint32 + if err := binary.Read(f.r, binary.BigEndian, &length); err != nil { + return nil, err + } + flags := ControlFlags((length & 0xff000000) >> 24) + length &= 0xffffff + header := ControlFrameHeader{version, frameType, flags, length} + cframe, err := newControlFrame(frameType) + if err != nil { + return nil, err + } + if err = cframe.read(header, f); err != nil { + return nil, err + } + return cframe, nil +} + +func parseHeaderValueBlock(r io.Reader) (http.Header, os.Error) { + var numHeaders uint16 + if err := binary.Read(r, binary.BigEndian, &numHeaders); err != nil { + return nil, err + } + h := make(http.Header, int(numHeaders)) + for i := 0; i < int(numHeaders); i++ { + var length uint16 + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + return nil, err + } + nameBytes := make([]byte, length) + if _, err := io.ReadFull(r, nameBytes); err != nil { + return nil, err + } + name := string(nameBytes) + if name != strings.ToLower(name) { + return nil, UnlowercasedHeaderName + } + if h[name] != nil { + return nil, DuplicateHeaders + } + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + return nil, err + } + value := make([]byte, length) + if _, err := io.ReadFull(r, value); err != nil { + return nil, err + } + valueList := strings.Split(string(value), "\x00", -1) + for _, v := range valueList { + h.Add(name, v) + } + } + return h, nil +} + +func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame) os.Error { + frame.CFHeader = h + var err os.Error + if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + if err = binary.Read(f.r, binary.BigEndian, &frame.AssociatedToStreamId); err != nil { + return err + } + if err = binary.Read(f.r, binary.BigEndian, &frame.Priority); err != nil { + return err + } + frame.Priority >>= 14 + + reader := f.r + if !f.headerCompressionDisabled { + f.uncorkHeaderDecompressor(int(h.length - 10)) + reader = f.headerDecompressor + } + + frame.Headers, err = parseHeaderValueBlock(reader) + if err != nil { + return err + } + return nil +} + +func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) os.Error { + frame.CFHeader = h + var err os.Error + if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + var unused uint16 + if err = binary.Read(f.r, binary.BigEndian, &unused); err != nil { + return err + } + reader := f.r + if !f.headerCompressionDisabled { + f.uncorkHeaderDecompressor(int(h.length - 6)) + reader = f.headerDecompressor + } + frame.Headers, err = parseHeaderValueBlock(reader) + if err != nil { + return err + } + return nil +} + +func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) os.Error { + frame.CFHeader = h + var err os.Error + if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + var unused uint16 + if err = binary.Read(f.r, binary.BigEndian, &unused); err != nil { + return err + } + reader := f.r + if !f.headerCompressionDisabled { + f.uncorkHeaderDecompressor(int(h.length - 6)) + reader = f.headerDecompressor + } + frame.Headers, err = parseHeaderValueBlock(reader) + if err != nil { + return err + } + return nil +} + +func (f *Framer) parseDataFrame(streamId uint32) (*DataFrame, os.Error) { + var length uint32 + if err := binary.Read(f.r, binary.BigEndian, &length); err != nil { + return nil, err + } + var frame DataFrame + frame.StreamId = streamId + frame.Flags = DataFlags(length >> 24) + length &= 0xffffff + frame.Data = make([]byte, length) + // TODO(willchan): Support compressed data frames. + if _, err := io.ReadFull(f.r, frame.Data); err != nil { + return nil, err + } + return &frame, nil +} diff --git a/src/pkg/http/spdy/spdy_test.go b/src/pkg/http/spdy/spdy_test.go new file mode 100644 index 000000000..9100e1ea8 --- /dev/null +++ b/src/pkg/http/spdy/spdy_test.go @@ -0,0 +1,496 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "bytes" + "http" + "io" + "reflect" + "testing" +) + +func TestHeaderParsing(t *testing.T) { + headers := http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + } + var headerValueBlockBuf bytes.Buffer + writeHeaderValueBlock(&headerValueBlockBuf, headers) + + newHeaders, err := parseHeaderValueBlock(&headerValueBlockBuf) + if err != nil { + t.Fatal("parseHeaderValueBlock:", err) + } + + if !reflect.DeepEqual(headers, newHeaders) { + t.Fatal("got: ", newHeaders, "\nwant: ", headers) + } +} + +func TestCreateParseSynStreamFrame(t *testing.T) { + buffer := new(bytes.Buffer) + framer := &Framer{ + headerCompressionDisabled: true, + w: buffer, + headerBuf: new(bytes.Buffer), + r: buffer, + } + synStreamFrame := SynStreamFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeSynStream, + }, + Headers: http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + }, + } + if err := framer.WriteFrame(&synStreamFrame); err != nil { + t.Fatal("WriteFrame without compression:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame without compression:", err) + } + parsedSynStreamFrame, ok := frame.(*SynStreamFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { + t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) + } + + // Test again with compression + buffer.Reset() + framer, err = NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + if err := framer.WriteFrame(&synStreamFrame); err != nil { + t.Fatal("WriteFrame with compression:", err) + } + frame, err = framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame with compression:", err) + } + parsedSynStreamFrame, ok = frame.(*SynStreamFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { + t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) + } +} + +func TestCreateParseSynReplyFrame(t *testing.T) { + buffer := new(bytes.Buffer) + framer := &Framer{ + headerCompressionDisabled: true, + w: buffer, + headerBuf: new(bytes.Buffer), + r: buffer, + } + synReplyFrame := SynReplyFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeSynReply, + }, + Headers: http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + }, + } + if err := framer.WriteFrame(&synReplyFrame); err != nil { + t.Fatal("WriteFrame without compression:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame without compression:", err) + } + parsedSynReplyFrame, ok := frame.(*SynReplyFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(synReplyFrame, *parsedSynReplyFrame) { + t.Fatal("got: ", *parsedSynReplyFrame, "\nwant: ", synReplyFrame) + } + + // Test again with compression + buffer.Reset() + framer, err = NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + if err := framer.WriteFrame(&synReplyFrame); err != nil { + t.Fatal("WriteFrame with compression:", err) + } + frame, err = framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame with compression:", err) + } + parsedSynReplyFrame, ok = frame.(*SynReplyFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(synReplyFrame, *parsedSynReplyFrame) { + t.Fatal("got: ", *parsedSynReplyFrame, "\nwant: ", synReplyFrame) + } +} + +func TestCreateParseRstStream(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + rstStreamFrame := RstStreamFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeRstStream, + }, + StreamId: 1, + Status: InvalidStream, + } + if err := framer.WriteFrame(&rstStreamFrame); err != nil { + t.Fatal("WriteFrame:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame:", err) + } + parsedRstStreamFrame, ok := frame.(*RstStreamFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(rstStreamFrame, *parsedRstStreamFrame) { + t.Fatal("got: ", *parsedRstStreamFrame, "\nwant: ", rstStreamFrame) + } +} + +func TestCreateParseSettings(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + settingsFrame := SettingsFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeSettings, + }, + FlagIdValues: []SettingsFlagIdValue{ + {FlagSettingsPersistValue, SettingsCurrentCwnd, 10}, + {FlagSettingsPersisted, SettingsUploadBandwidth, 1}, + }, + } + if err := framer.WriteFrame(&settingsFrame); err != nil { + t.Fatal("WriteFrame:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame:", err) + } + parsedSettingsFrame, ok := frame.(*SettingsFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(settingsFrame, *parsedSettingsFrame) { + t.Fatal("got: ", *parsedSettingsFrame, "\nwant: ", settingsFrame) + } +} + +func TestCreateParseNoop(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + noopFrame := NoopFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeNoop, + }, + } + if err := framer.WriteFrame(&noopFrame); err != nil { + t.Fatal("WriteFrame:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame:", err) + } + parsedNoopFrame, ok := frame.(*NoopFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(noopFrame, *parsedNoopFrame) { + t.Fatal("got: ", *parsedNoopFrame, "\nwant: ", noopFrame) + } +} + +func TestCreateParsePing(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + pingFrame := PingFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypePing, + }, + Id: 31337, + } + if err := framer.WriteFrame(&pingFrame); err != nil { + t.Fatal("WriteFrame:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame:", err) + } + parsedPingFrame, ok := frame.(*PingFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(pingFrame, *parsedPingFrame) { + t.Fatal("got: ", *parsedPingFrame, "\nwant: ", pingFrame) + } +} + +func TestCreateParseGoAway(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + goAwayFrame := GoAwayFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeGoAway, + }, + LastGoodStreamId: 31337, + } + if err := framer.WriteFrame(&goAwayFrame); err != nil { + t.Fatal("WriteFrame:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame:", err) + } + parsedGoAwayFrame, ok := frame.(*GoAwayFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(goAwayFrame, *parsedGoAwayFrame) { + t.Fatal("got: ", *parsedGoAwayFrame, "\nwant: ", goAwayFrame) + } +} + +func TestCreateParseHeadersFrame(t *testing.T) { + buffer := new(bytes.Buffer) + framer := &Framer{ + headerCompressionDisabled: true, + w: buffer, + headerBuf: new(bytes.Buffer), + r: buffer, + } + headersFrame := HeadersFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeHeaders, + }, + } + headersFrame.Headers = http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + } + if err := framer.WriteFrame(&headersFrame); err != nil { + t.Fatal("WriteFrame without compression:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame without compression:", err) + } + parsedHeadersFrame, ok := frame.(*HeadersFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { + t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) + } + + // Test again with compression + buffer.Reset() + framer, err = NewFramer(buffer, buffer) + if err := framer.WriteFrame(&headersFrame); err != nil { + t.Fatal("WriteFrame with compression:", err) + } + frame, err = framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame with compression:", err) + } + parsedHeadersFrame, ok = frame.(*HeadersFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { + t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) + } +} + +func TestCreateParseDataFrame(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + dataFrame := DataFrame{ + StreamId: 1, + Data: []byte{'h', 'e', 'l', 'l', 'o'}, + } + if err := framer.WriteFrame(&dataFrame); err != nil { + t.Fatal("WriteFrame:", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame:", err) + } + parsedDataFrame, ok := frame.(*DataFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(dataFrame, *parsedDataFrame) { + t.Fatal("got: ", *parsedDataFrame, "\nwant: ", dataFrame) + } +} + +func TestCompressionContextAcrossFrames(t *testing.T) { + buffer := new(bytes.Buffer) + framer, err := NewFramer(buffer, buffer) + if err != nil { + t.Fatal("Failed to create new framer:", err) + } + headersFrame := HeadersFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeHeaders, + }, + Headers: http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + }, + } + if err := framer.WriteFrame(&headersFrame); err != nil { + t.Fatal("WriteFrame (HEADERS):", err) + } + synStreamFrame := SynStreamFrame{ControlFrameHeader{Version, TypeSynStream, 0, 0}, 0, 0, 0, nil} + synStreamFrame.Headers = http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + } + if err := framer.WriteFrame(&synStreamFrame); err != nil { + t.Fatal("WriteFrame (SYN_STREAM):", err) + } + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame (HEADERS):", err, buffer.Bytes()) + } + parsedHeadersFrame, ok := frame.(*HeadersFrame) + if !ok { + t.Fatalf("expected HeadersFrame; got %T %v", frame, frame) + } + if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { + t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) + } + frame, err = framer.ReadFrame() + if err != nil { + t.Fatal("ReadFrame (SYN_STREAM):", err, buffer.Bytes()) + } + parsedSynStreamFrame, ok := frame.(*SynStreamFrame) + if !ok { + t.Fatalf("expected SynStreamFrame; got %T %v", frame, frame) + } + if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { + t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) + } +} + +func TestMultipleSPDYFrames(t *testing.T) { + // Initialize the framers. + pr1, pw1 := io.Pipe() + pr2, pw2 := io.Pipe() + writer, err := NewFramer(pw1, pr2) + if err != nil { + t.Fatal("Failed to create writer:", err) + } + reader, err := NewFramer(pw2, pr1) + if err != nil { + t.Fatal("Failed to create reader:", err) + } + + // Set up the frames we're actually transferring. + headersFrame := HeadersFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeHeaders, + }, + Headers: http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + }, + } + synStreamFrame := SynStreamFrame{ + CFHeader: ControlFrameHeader{ + version: Version, + frameType: TypeSynStream, + }, + Headers: http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + }, + } + + // Start the goroutines to write the frames. + go func() { + if err := writer.WriteFrame(&headersFrame); err != nil { + t.Fatal("WriteFrame (HEADERS): ", err) + } + if err := writer.WriteFrame(&synStreamFrame); err != nil { + t.Fatal("WriteFrame (SYN_STREAM): ", err) + } + }() + + // Read the frames and verify they look as expected. + frame, err := reader.ReadFrame() + if err != nil { + t.Fatal("ReadFrame (HEADERS): ", err) + } + parsedHeadersFrame, ok := frame.(*HeadersFrame) + if !ok { + t.Fatal("Parsed incorrect frame type:", frame) + } + if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { + t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) + } + frame, err = reader.ReadFrame() + if err != nil { + t.Fatal("ReadFrame (SYN_STREAM):", err) + } + parsedSynStreamFrame, ok := frame.(*SynStreamFrame) + if !ok { + t.Fatal("Parsed incorrect frame type.") + } + if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { + t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) + } +} diff --git a/src/pkg/http/spdy/types.go b/src/pkg/http/spdy/types.go new file mode 100644 index 000000000..5a665f04f --- /dev/null +++ b/src/pkg/http/spdy/types.go @@ -0,0 +1,363 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "bytes" + "compress/zlib" + "http" + "io" + "os" + "strconv" +) + +// Data Frame Format +// +----------------------------------+ +// |0| Stream-ID (31bits) | +// +----------------------------------+ +// | flags (8) | Length (24 bits) | +// +----------------------------------+ +// | Data | +// +----------------------------------+ +// +// Control Frame Format +// +----------------------------------+ +// |1| Version(15bits) | Type(16bits) | +// +----------------------------------+ +// | flags (8) | Length (24 bits) | +// +----------------------------------+ +// | Data | +// +----------------------------------+ +// +// Control Frame: SYN_STREAM +// +----------------------------------+ +// |1|000000000000001|0000000000000001| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | >= 12 +// +----------------------------------+ +// |X| Stream-ID(31bits) | +// +----------------------------------+ +// |X|Associated-To-Stream-ID (31bits)| +// +----------------------------------+ +// |Pri| unused | Length (16bits)| +// +----------------------------------+ +// +// Control Frame: SYN_REPLY +// +----------------------------------+ +// |1|000000000000001|0000000000000010| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | >= 8 +// +----------------------------------+ +// |X| Stream-ID(31bits) | +// +----------------------------------+ +// | unused (16 bits)| Length (16bits)| +// +----------------------------------+ +// +// Control Frame: RST_STREAM +// +----------------------------------+ +// |1|000000000000001|0000000000000011| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | >= 4 +// +----------------------------------+ +// |X| Stream-ID(31bits) | +// +----------------------------------+ +// | Status code (32 bits) | +// +----------------------------------+ +// +// Control Frame: SETTINGS +// +----------------------------------+ +// |1|000000000000001|0000000000000100| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | +// +----------------------------------+ +// | # of entries (32) | +// +----------------------------------+ +// +// Control Frame: NOOP +// +----------------------------------+ +// |1|000000000000001|0000000000000101| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | = 0 +// +----------------------------------+ +// +// Control Frame: PING +// +----------------------------------+ +// |1|000000000000001|0000000000000110| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | = 4 +// +----------------------------------+ +// | Unique id (32 bits) | +// +----------------------------------+ +// +// Control Frame: GOAWAY +// +----------------------------------+ +// |1|000000000000001|0000000000000111| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | = 4 +// +----------------------------------+ +// |X| Last-accepted-stream-id | +// +----------------------------------+ +// +// Control Frame: HEADERS +// +----------------------------------+ +// |1|000000000000001|0000000000001000| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | >= 8 +// +----------------------------------+ +// |X| Stream-ID (31 bits) | +// +----------------------------------+ +// | unused (16 bits)| Length (16bits)| +// +----------------------------------+ +// +// Control Frame: WINDOW_UPDATE +// +----------------------------------+ +// |1|000000000000001|0000000000001001| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | = 8 +// +----------------------------------+ +// |X| Stream-ID (31 bits) | +// +----------------------------------+ +// | Delta-Window-Size (32 bits) | +// +----------------------------------+ + +// Version is the protocol version number that this package implements. +const Version = 2 + +// ControlFrameType stores the type field in a control frame header. +type ControlFrameType uint16 + +// Control frame type constants +const ( + TypeSynStream ControlFrameType = 0x0001 + TypeSynReply = 0x0002 + TypeRstStream = 0x0003 + TypeSettings = 0x0004 + TypeNoop = 0x0005 + TypePing = 0x0006 + TypeGoAway = 0x0007 + TypeHeaders = 0x0008 + TypeWindowUpdate = 0x0009 +) + +// ControlFlags are the flags that can be set on a control frame. +type ControlFlags uint8 + +const ( + ControlFlagFin ControlFlags = 0x01 +) + +// DataFlags are the flags that can be set on a data frame. +type DataFlags uint8 + +const ( + DataFlagFin DataFlags = 0x01 + DataFlagCompressed = 0x02 +) + +// MaxDataLength is the maximum number of bytes that can be stored in one frame. +const MaxDataLength = 1<<24 - 1 + +// Frame is a single SPDY frame in its unpacked in-memory representation. Use +// Framer to read and write it. +type Frame interface { + write(f *Framer) os.Error +} + +// ControlFrameHeader contains all the fields in a control frame header, +// in its unpacked in-memory representation. +type ControlFrameHeader struct { + // Note, high bit is the "Control" bit. + version uint16 + frameType ControlFrameType + Flags ControlFlags + length uint32 +} + +type controlFrame interface { + Frame + read(h ControlFrameHeader, f *Framer) os.Error +} + +// SynStreamFrame is the unpacked, in-memory representation of a SYN_STREAM +// frame. +type SynStreamFrame struct { + CFHeader ControlFrameHeader + StreamId uint32 + AssociatedToStreamId uint32 + // Note, only 2 highest bits currently used + // Rest of Priority is unused. + Priority uint16 + Headers http.Header +} + +// SynReplyFrame is the unpacked, in-memory representation of a SYN_REPLY frame. +type SynReplyFrame struct { + CFHeader ControlFrameHeader + StreamId uint32 + Headers http.Header +} + +// StatusCode represents the status that led to a RST_STREAM +type StatusCode uint32 + +const ( + ProtocolError StatusCode = 1 + InvalidStream = 2 + RefusedStream = 3 + UnsupportedVersion = 4 + Cancel = 5 + InternalError = 6 + FlowControlError = 7 +) + +// RstStreamFrame is the unpacked, in-memory representation of a RST_STREAM +// frame. +type RstStreamFrame struct { + CFHeader ControlFrameHeader + StreamId uint32 + Status StatusCode +} + +// SettingsFlag represents a flag in a SETTINGS frame. +type SettingsFlag uint8 + +const ( + FlagSettingsPersistValue SettingsFlag = 0x1 + FlagSettingsPersisted = 0x2 +) + +// SettingsFlag represents the id of an id/value pair in a SETTINGS frame. +type SettingsId uint32 + +const ( + SettingsUploadBandwidth SettingsId = 1 + SettingsDownloadBandwidth = 2 + SettingsRoundTripTime = 3 + SettingsMaxConcurrentStreams = 4 + SettingsCurrentCwnd = 5 +) + +// SettingsFlagIdValue is the unpacked, in-memory representation of the +// combined flag/id/value for a setting in a SETTINGS frame. +type SettingsFlagIdValue struct { + Flag SettingsFlag + Id SettingsId + Value uint32 +} + +// SettingsFrame is the unpacked, in-memory representation of a SPDY +// SETTINGS frame. +type SettingsFrame struct { + CFHeader ControlFrameHeader + FlagIdValues []SettingsFlagIdValue +} + +// NoopFrame is the unpacked, in-memory representation of a NOOP frame. +type NoopFrame struct { + CFHeader ControlFrameHeader +} + +// PingFrame is the unpacked, in-memory representation of a PING frame. +type PingFrame struct { + CFHeader ControlFrameHeader + Id uint32 +} + +// GoAwayFrame is the unpacked, in-memory representation of a GOAWAY frame. +type GoAwayFrame struct { + CFHeader ControlFrameHeader + LastGoodStreamId uint32 +} + +// HeadersFrame is the unpacked, in-memory representation of a HEADERS frame. +type HeadersFrame struct { + CFHeader ControlFrameHeader + StreamId uint32 + Headers http.Header +} + +// DataFrame is the unpacked, in-memory representation of a DATA frame. +type DataFrame struct { + // Note, high bit is the "Control" bit. Should be 0 for data frames. + StreamId uint32 + Flags DataFlags + Data []byte +} + +// HeaderDictionary is the dictionary sent to the zlib compressor/decompressor. +// Even though the specification states there is no null byte at the end, Chrome sends it. +const HeaderDictionary = "optionsgetheadpostputdeletetrace" + + "acceptaccept-charsetaccept-encodingaccept-languageauthorizationexpectfromhost" + + "if-modified-sinceif-matchif-none-matchif-rangeif-unmodifiedsince" + + "max-forwardsproxy-authorizationrangerefererteuser-agent" + + "100101200201202203204205206300301302303304305306307400401402403404405406407408409410411412413414415416417500501502503504505" + + "accept-rangesageetaglocationproxy-authenticatepublicretry-after" + + "servervarywarningwww-authenticateallowcontent-basecontent-encodingcache-control" + + "connectiondatetrailertransfer-encodingupgradeviawarning" + + "content-languagecontent-lengthcontent-locationcontent-md5content-rangecontent-typeetagexpireslast-modifiedset-cookie" + + "MondayTuesdayWednesdayThursdayFridaySaturdaySunday" + + "JanFebMarAprMayJunJulAugSepOctNovDec" + + "chunkedtext/htmlimage/pngimage/jpgimage/gifapplication/xmlapplication/xhtmltext/plainpublicmax-age" + + "charset=iso-8859-1utf-8gzipdeflateHTTP/1.1statusversionurl\x00" + +type FramerError int + +const ( + Internal FramerError = iota + InvalidControlFrame + UnlowercasedHeaderName + DuplicateHeaders + UnknownFrameType + InvalidDataFrame +) + +func (e FramerError) String() string { + switch e { + case Internal: + return "Internal" + case InvalidControlFrame: + return "InvalidControlFrame" + case UnlowercasedHeaderName: + return "UnlowercasedHeaderName" + case DuplicateHeaders: + return "DuplicateHeaders" + case UnknownFrameType: + return "UnknownFrameType" + case InvalidDataFrame: + return "InvalidDataFrame" + } + return "Error(" + strconv.Itoa(int(e)) + ")" +} + +// Framer handles serializing/deserializing SPDY frames, including compressing/ +// decompressing payloads. +type Framer struct { + headerCompressionDisabled bool + w io.Writer + headerBuf *bytes.Buffer + headerCompressor *zlib.Writer + r io.Reader + headerReader corkedReader + headerDecompressor io.ReadCloser +} + +// NewFramer allocates a new Framer for a given SPDY connection, repesented by +// a io.Writer and io.Reader. Note that Framer will read and write individual fields +// from/to the Reader and Writer, so the caller should pass in an appropriately +// buffered implementation to optimize performance. +func NewFramer(w io.Writer, r io.Reader) (*Framer, os.Error) { + compressBuf := new(bytes.Buffer) + compressor, err := zlib.NewWriterDict(compressBuf, zlib.BestCompression, []byte(HeaderDictionary)) + if err != nil { + return nil, err + } + framer := &Framer{ + w: w, + headerBuf: compressBuf, + headerCompressor: compressor, + r: r, + } + return framer, nil +} diff --git a/src/pkg/http/spdy/write.go b/src/pkg/http/spdy/write.go new file mode 100644 index 000000000..aa1679f1b --- /dev/null +++ b/src/pkg/http/spdy/write.go @@ -0,0 +1,287 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "encoding/binary" + "http" + "io" + "os" + "strings" +) + +func (frame *SynStreamFrame) write(f *Framer) os.Error { + return f.writeSynStreamFrame(frame) +} + +func (frame *SynReplyFrame) write(f *Framer) os.Error { + return f.writeSynReplyFrame(frame) +} + +func (frame *RstStreamFrame) write(f *Framer) (err os.Error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeRstStream + frame.CFHeader.length = 8 + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.Status); err != nil { + return + } + return +} + +func (frame *SettingsFrame) write(f *Framer) (err os.Error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeSettings + frame.CFHeader.length = uint32(len(frame.FlagIdValues)*8 + 4) + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, uint32(len(frame.FlagIdValues))); err != nil { + return + } + for _, flagIdValue := range frame.FlagIdValues { + flagId := (uint32(flagIdValue.Flag) << 24) | uint32(flagIdValue.Id) + if err = binary.Write(f.w, binary.BigEndian, flagId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, flagIdValue.Value); err != nil { + return + } + } + return +} + +func (frame *NoopFrame) write(f *Framer) os.Error { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeNoop + + // Serialize frame to Writer + return writeControlFrameHeader(f.w, frame.CFHeader) +} + +func (frame *PingFrame) write(f *Framer) (err os.Error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypePing + frame.CFHeader.length = 4 + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.Id); err != nil { + return + } + return +} + +func (frame *GoAwayFrame) write(f *Framer) (err os.Error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeGoAway + frame.CFHeader.length = 4 + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.LastGoodStreamId); err != nil { + return + } + return nil +} + +func (frame *HeadersFrame) write(f *Framer) os.Error { + return f.writeHeadersFrame(frame) +} + +func (frame *DataFrame) write(f *Framer) os.Error { + return f.writeDataFrame(frame) +} + +// WriteFrame writes a frame. +func (f *Framer) WriteFrame(frame Frame) os.Error { + return frame.write(f) +} + +func writeControlFrameHeader(w io.Writer, h ControlFrameHeader) os.Error { + if err := binary.Write(w, binary.BigEndian, 0x8000|h.version); err != nil { + return err + } + if err := binary.Write(w, binary.BigEndian, h.frameType); err != nil { + return err + } + flagsAndLength := (uint32(h.Flags) << 24) | h.length + if err := binary.Write(w, binary.BigEndian, flagsAndLength); err != nil { + return err + } + return nil +} + +func writeHeaderValueBlock(w io.Writer, h http.Header) (n int, err os.Error) { + n = 0 + if err = binary.Write(w, binary.BigEndian, uint16(len(h))); err != nil { + return + } + n += 2 + for name, values := range h { + if err = binary.Write(w, binary.BigEndian, uint16(len(name))); err != nil { + return + } + n += 2 + name = strings.ToLower(name) + if _, err = io.WriteString(w, name); err != nil { + return + } + n += len(name) + v := strings.Join(values, "\x00") + if err = binary.Write(w, binary.BigEndian, uint16(len(v))); err != nil { + return + } + n += 2 + if _, err = io.WriteString(w, v); err != nil { + return + } + n += len(v) + } + return +} + +func (f *Framer) writeSynStreamFrame(frame *SynStreamFrame) (err os.Error) { + // Marshal the headers. + var writer io.Writer = f.headerBuf + if !f.headerCompressionDisabled { + writer = f.headerCompressor + } + if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { + return + } + if !f.headerCompressionDisabled { + f.headerCompressor.Flush() + } + + // Set ControlFrameHeader + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeSynStream + frame.CFHeader.length = uint32(len(f.headerBuf.Bytes()) + 10) + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.AssociatedToStreamId); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.Priority<<14); err != nil { + return err + } + if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { + return err + } + f.headerBuf.Reset() + return nil +} + +func (f *Framer) writeSynReplyFrame(frame *SynReplyFrame) (err os.Error) { + // Marshal the headers. + var writer io.Writer = f.headerBuf + if !f.headerCompressionDisabled { + writer = f.headerCompressor + } + if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { + return + } + if !f.headerCompressionDisabled { + f.headerCompressor.Flush() + } + + // Set ControlFrameHeader + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeSynReply + frame.CFHeader.length = uint32(len(f.headerBuf.Bytes()) + 6) + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, uint16(0)); err != nil { + return + } + if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { + return + } + f.headerBuf.Reset() + return +} + +func (f *Framer) writeHeadersFrame(frame *HeadersFrame) (err os.Error) { + // Marshal the headers. + var writer io.Writer = f.headerBuf + if !f.headerCompressionDisabled { + writer = f.headerCompressor + } + if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { + return + } + if !f.headerCompressionDisabled { + f.headerCompressor.Flush() + } + + // Set ControlFrameHeader + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeHeaders + frame.CFHeader.length = uint32(len(f.headerBuf.Bytes()) + 6) + + // Serialize frame to Writer + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, uint16(0)); err != nil { + return + } + if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { + return + } + f.headerBuf.Reset() + return +} + +func (f *Framer) writeDataFrame(frame *DataFrame) (err os.Error) { + // Validate DataFrame + if frame.StreamId&0x80000000 != 0 || len(frame.Data) >= 0x0f000000 { + return InvalidDataFrame + } + + // TODO(willchan): Support data compression. + // Serialize frame to Writer + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + flagsAndLength := (uint32(frame.Flags) << 24) | uint32(len(frame.Data)) + if err = binary.Write(f.w, binary.BigEndian, flagsAndLength); err != nil { + return + } + if _, err = f.w.Write(frame.Data); err != nil { + return + } + + return nil +} diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go index 0fa8bed43..b54508e7a 100644 --- a/src/pkg/http/transfer.go +++ b/src/pkg/http/transfer.go @@ -38,6 +38,9 @@ func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer atLeastHTTP11 = rr.ProtoAtLeast(1, 1) + if t.Body != nil && t.ContentLength <= 0 && len(t.TransferEncoding) == 0 && atLeastHTTP11 { + t.TransferEncoding = []string{"chunked"} + } case *Response: t.Body = rr.Body t.ContentLength = rr.ContentLength @@ -45,7 +48,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - t.ResponseToHEAD = noBodyExpected(rr.RequestMethod) + t.ResponseToHEAD = noBodyExpected(rr.Request.Method) } // Sanitize Body,ContentLength,TransferEncoding @@ -95,7 +98,7 @@ func (t *transferWriter) WriteHeader(w io.Writer) (err os.Error) { if err != nil { return } - } else if t.ContentLength > 0 || t.ResponseToHEAD { + } else if t.ContentLength > 0 || t.ResponseToHEAD || (t.ContentLength == 0 && isIdentity(t.TransferEncoding)) { io.WriteString(w, "Content-Length: ") _, err = io.WriteString(w, strconv.Itoa64(t.ContentLength)+"\r\n") if err != nil { @@ -196,7 +199,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { case *Response: t.Header = rr.Header t.StatusCode = rr.StatusCode - t.RequestMethod = rr.RequestMethod + t.RequestMethod = rr.Request.Method t.ProtoMajor = rr.ProtoMajor t.ProtoMinor = rr.ProtoMinor t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header) @@ -289,6 +292,9 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { // Checks whether chunked is part of the encodings stack func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } +// Checks whether the encoding is explicitly "identity". +func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } + // Sanitize transfer encoding func fixTransferEncoding(requestMethod string, header Header) ([]string, os.Error) { raw, present := header["Transfer-Encoding"] diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index 73a2c2191..c907d85fd 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -6,12 +6,12 @@ package http import ( "bufio" - "bytes" "compress/gzip" "crypto/tls" "encoding/base64" "fmt" "io" + "io/ioutil" "log" "net" "os" @@ -24,7 +24,7 @@ import ( // each call to Do and uses HTTP proxies as directed by the // $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) // environment variables. -var DefaultTransport RoundTripper = &Transport{} +var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. @@ -36,12 +36,23 @@ const DefaultMaxIdleConnsPerHost = 2 type Transport struct { lk sync.Mutex idleConn map[string][]*persistConn + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // TODO: tunable on global max cached connections // TODO: tunable on timeout on cached connections // TODO: optional pipelining - IgnoreEnvironment bool // don't look at environment variables for proxy configuration + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*Request) (*URL, os.Error) + + // Dial specifies the dial function for creating TCP + // connections. + // If Dial is nil, net.Dial is used. + Dial func(net, addr string) (c net.Conn, err os.Error) + DisableKeepAlives bool DisableCompression bool @@ -51,6 +62,39 @@ type Transport struct { MaxIdleConnsPerHost int } +// ProxyFromEnvironment returns the URL of the proxy to use for a +// given request, as indicated by the environment variables +// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy). +// Either URL or an error is returned. +func ProxyFromEnvironment(req *Request) (*URL, os.Error) { + proxy := getenvEitherCase("HTTP_PROXY") + if proxy == "" { + return nil, nil + } + if !useProxy(canonicalAddr(req.URL)) { + return nil, nil + } + proxyURL, err := ParseRequestURL(proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + if proxyURL.Host == "" { + proxyURL, err = ParseRequestURL("http://" + proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + } + return proxyURL, nil +} + +// ProxyURL returns a proxy function (for use in a Transport) +// that always returns the same URL. +func ProxyURL(url *URL) func(*Request) (*URL, os.Error) { + return func(*Request) (*URL, os.Error) { + return url, nil + } +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { if req.URL == nil { @@ -59,7 +103,16 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { } } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + t.lk.Lock() + var rt RoundTripper + if t.altProto != nil { + rt = t.altProto[req.URL.Scheme] + } + t.lk.Unlock() + if rt == nil { + return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + } + return rt.RoundTrip(req) } cm, err := t.connectMethodForRequest(req) @@ -79,6 +132,27 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { return pconn.roundTrip(req) } +// RegisterProtocol registers a new protocol with scheme. +// The Transport will pass requests using the given scheme to rt. +// It is rt's responsibility to simulate HTTP request semantics. +// +// RegisterProtocol can be used by other packages to provide +// implementations of protocol schemes like "ftp" or "file". +func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { + if scheme == "http" || scheme == "https" { + panic("protocol " + scheme + " already registered") + } + t.lk.Lock() + defer t.lk.Unlock() + if t.altProto == nil { + t.altProto = make(map[string]RoundTripper) + } + if _, exists := t.altProto[scheme]; exists { + panic("protocol " + scheme + " already registered") + } + t.altProto[scheme] = rt +} + // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle in // a "keep-alive" state. It does not interrupt any connections currently @@ -101,21 +175,11 @@ func (t *Transport) CloseIdleConnections() { // Private implementation past this point. // -func (t *Transport) getenvEitherCase(k string) string { - if t.IgnoreEnvironment { - return "" - } - if v := t.getenv(strings.ToUpper(k)); v != "" { +func getenvEitherCase(k string) string { + if v := os.Getenv(strings.ToUpper(k)); v != "" { return v } - return t.getenv(strings.ToLower(k)) -} - -func (t *Transport) getenv(k string) string { - if t.IgnoreEnvironment { - return "" - } - return os.Getenv(k) + return os.Getenv(strings.ToLower(k)) } func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { @@ -123,20 +187,12 @@ func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Er targetScheme: req.URL.Scheme, targetAddr: canonicalAddr(req.URL), } - - proxy := t.getenvEitherCase("HTTP_PROXY") - if proxy != "" && t.useProxy(cm.targetAddr) { - proxyURL, err := ParseRequestURL(proxy) + if t.Proxy != nil { + var err os.Error + cm.proxyURL, err = t.Proxy(req) if err != nil { - return nil, os.ErrorString("invalid proxy address") - } - if proxyURL.Host == "" { - proxyURL, err = ParseRequestURL("http://" + proxy) - if err != nil { - return nil, os.ErrorString("invalid proxy address") - } + return nil, err } - cm.proxyURL = proxyURL } return cm, nil } @@ -149,10 +205,7 @@ func (cm *connectMethod) proxyAuth() string { } proxyInfo := cm.proxyURL.RawUserinfo if proxyInfo != "" { - enc := base64.URLEncoding - encoded := make([]byte, enc.EncodedLen(len(proxyInfo))) - enc.Encode(encoded, []byte(proxyInfo)) - return "Basic " + string(encoded) + return "Basic " + base64.URLEncoding.EncodeToString([]byte(proxyInfo)) } return "" } @@ -207,6 +260,13 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } +func (t *Transport) dial(network, addr string) (c net.Conn, err os.Error) { + if t.Dial != nil { + return t.Dial(network, addr) + } + return net.Dial(network, addr) +} + // getConn dials and creates a new persistConn to the target as // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn @@ -216,7 +276,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { return pc, nil } - conn, err := net.Dial("tcp", cm.addr()) + conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) @@ -248,18 +308,22 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { } } case cm.targetScheme == "https": - fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", cm.targetAddr) - fmt.Fprintf(conn, "Host: %s\r\n", cm.targetAddr) + connectReq := &Request{ + Method: "CONNECT", + RawURL: cm.targetAddr, + Host: cm.targetAddr, + Header: make(Header), + } if pa != "" { - fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", pa) + connectReq.Header.Set("Proxy-Authorization", pa) } - fmt.Fprintf(conn, "\r\n") + connectReq.Write(conn) // Read response. // Okay to use and discard buffered reader here, because // TLS server will not speak until spoken to. br := bufio.NewReader(conn) - resp, err := ReadResponse(br, "CONNECT") + resp, err := ReadResponse(br, connectReq) if err != nil { conn.Close() return nil, err @@ -285,7 +349,6 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { pconn.br = bufio.NewReader(pconn.conn) pconn.cc = newClientConnFunc(conn, pconn.br) - pconn.cc.readRes = readResponseWithEOFSignal go pconn.readLoop() return pconn, nil } @@ -293,7 +356,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. // addr is always a canonicalAddr with a host and port. -func (t *Transport) useProxy(addr string) bool { +func useProxy(addr string) bool { if len(addr) == 0 { return true } @@ -305,16 +368,12 @@ func (t *Transport) useProxy(addr string) bool { return false } if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { - // 127.0.0.0/8 loopback isn't proxied. - return false - } - if bytes.Equal(ip, net.IPv6loopback) { + if ip.IsLoopback() { return false } } - no_proxy := t.getenvEitherCase("NO_PROXY") + no_proxy := getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false } @@ -447,7 +506,25 @@ func (pc *persistConn) readLoop() { } rc := <-pc.reqch - resp, err := pc.cc.Read(rc.req) + resp, err := pc.cc.readUsing(rc.req, func(buf *bufio.Reader, forReq *Request) (*Response, os.Error) { + resp, err := ReadResponse(buf, forReq) + if err != nil || resp.ContentLength == 0 { + return resp, err + } + if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + gzReader, err := gzip.NewReader(resp.Body) + if err != nil { + pc.close() + return nil, err + } + resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + } + resp.Body = &bodyEOFSignal{body: resp.Body} + return resp, err + }) if err == ErrPersistEOF { // Succeeded, but we can't send any more @@ -469,6 +546,17 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- true } } else { + // When there's no response body, we immediately + // reuse the TCP connection (putIdleConn), but + // we need to prevent ClientConn.Read from + // closing the Response.Body on the next + // loop, otherwise it might close the body + // before the client code has had a chance to + // read it (even though it'll just be 0, EOF). + pc.cc.lk.Lock() + pc.cc.lastbody = nil + pc.cc.lk.Unlock() + pc.t.putIdleConn(pc) } } @@ -491,6 +579,11 @@ type responseAndError struct { type requestAndChan struct { req *Request ch chan responseAndError + + // did the Transport (as opposed to the client code) add an + // Accept-Encoding gzip header? only if it we set it do + // we transparently decode the gzip. + addedGzip bool } func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { @@ -522,25 +615,12 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { } ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req, ch} + pc.reqch <- requestAndChan{req, ch, requestedGzip} re := <-ch pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() - if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" { - re.res.Header.Del("Content-Encoding") - re.res.Header.Del("Content-Length") - re.res.ContentLength = -1 - esb := re.res.Body.(*bodyEOFSignal) - gzReader, err := gzip.NewReader(esb.body) - if err != nil { - pc.close() - return nil, err - } - esb.body = &readFirstCloseBoth{gzReader, esb.body} - } - return re.res, re.err } @@ -572,16 +652,6 @@ func responseIsKeepAlive(res *Response) bool { return false } -// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces -// the response body with a bodyEOFSignal-wrapped version. -func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { - resp, err = ReadResponse(r, requestMethod) - if err == nil && resp.ContentLength != 0 { - resp.Body = &bodyEOFSignal{body: resp.Body} - } - return -} - // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before the final Read() or Close() call returns, but after // EOF has been seen. @@ -604,6 +674,9 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { } func (es *bodyEOFSignal) Close() (err os.Error) { + if es.isClosed { + return nil + } es.isClosed = true err = es.body.Close() if err == nil && es.fn != nil { @@ -628,3 +701,13 @@ func (r *readFirstCloseBoth) Close() os.Error { } return nil } + +// discardOnCloseReadCloser consumes all its input on Close. +type discardOnCloseReadCloser struct { + io.ReadCloser +} + +func (d *discardOnCloseReadCloser) Close() os.Error { + io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed + return d.ReadCloser.Close() +} diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go index a32ac4c4f..76e97640e 100644 --- a/src/pkg/http/transport_test.go +++ b/src/pkg/http/transport_test.go @@ -17,6 +17,7 @@ import ( "io/ioutil" "os" "strconv" + "strings" "testing" "time" ) @@ -43,7 +44,7 @@ func TestTransportKeepAlives(t *testing.T) { c := &Client{Transport: tr} fetch := func(n int) string { - res, _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) } @@ -160,7 +161,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) } - resp, _, err := c.Get(ts.URL) + resp, err := c.Get(ts.URL) if err != nil { t.Error(err) } @@ -201,7 +202,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { // Their responses will hang until we we write to resch, though. donech := make(chan bool) doReq := func() { - resp, _, err := c.Get(ts.URL) + resp, err := c.Get(ts.URL) if err != nil { t.Error(err) } @@ -256,26 +257,44 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { tr := &Transport{} c := &Client{Transport: tr} - fetch := func(n int) string { - res, _, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("error in req #%d, GET: %v", n, err) + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...interface{}) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(1e9 / int64(retries)) } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in req #%d, ReadAll: %v", n, err) + for retries >= 0 { + retries-- + res, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) } - res.Body.Close() - return string(body) + panic("unreachable") } - body1 := fetch(1) - body2 := fetch(2) + body1 := fetch(1, 0) + body2 := fetch(2, 0) ts.CloseClientConnections() // surprise! - time.Sleep(25e6) // idle for a bit (test is inherently racey, but expectedly) - body3 := fetch(3) + // This test has an expected race. Sleeping for 25 ms prevents + // it on most fast machines, causing the next fetch() call to + // succeed quickly. But if we do get errors, fetch() will retry 5 + // times with some delays between. + time.Sleep(25e6) + + body3 := fetch(3, 5) if body1 != body2 { t.Errorf("expected body1 and body2 to be equal") @@ -376,6 +395,9 @@ func TestTransportGzip(t *testing.T) { t.Errorf("Accept-Encoding = %q, want %q", g, e) } rw.Header().Set("Content-Encoding", "gzip") + if req.Method == "HEAD" { + return + } var w io.Writer = rw var buf bytes.Buffer @@ -399,7 +421,7 @@ func TestTransportGzip(t *testing.T) { c := &Client{Transport: &Transport{}} // First fetch something large, but only read some of it. - res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) + res, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) if err != nil { t.Fatalf("large get: %v", err) } @@ -419,7 +441,7 @@ func TestTransportGzip(t *testing.T) { } // Then something small. - res, _, err = c.Get(ts.URL + "?chunked=" + chunked) + res, err = c.Get(ts.URL + "?chunked=" + chunked) if err != nil { t.Fatal(err) } @@ -445,6 +467,40 @@ func TestTransportGzip(t *testing.T) { t.Errorf("expected Read error after Close; got %d, %v", n, err) } } + + // And a HEAD request too, because they're always weird. + c := &Client{Transport: &Transport{}} + res, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("Head: %v", err) + } + if res.StatusCode != 200 { + t.Errorf("Head status=%d; want=200", res.StatusCode) + } +} + +func TestTransportProxy(t *testing.T) { + ch := make(chan string, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "real server" + })) + defer ts.Close() + proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "proxy for " + r.URL.String() + })) + defer proxy.Close() + + pu, err := ParseURL(proxy.URL) + if err != nil { + t.Fatal(err) + } + c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} + c.Head(ts.URL) + got := <-ch + want := "proxy for " + ts.URL + "/" + if got != want { + t.Errorf("want %q, got %q", want, got) + } } // TestTransportGzipRecursive sends a gzip quine and checks that the @@ -459,7 +515,7 @@ func TestTransportGzipRecursive(t *testing.T) { defer ts.Close() c := &Client{Transport: &Transport{}} - res, _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -476,6 +532,36 @@ func TestTransportGzipRecursive(t *testing.T) { } } +type fooProto struct{} + +func (fooProto) RoundTrip(req *Request) (*Response, os.Error) { + res := &Response{ + Status: "200 OK", + StatusCode: 200, + Header: make(Header), + Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + } + return res, nil +} + +func TestTransportAltProto(t *testing.T) { + tr := &Transport{} + c := &Client{Transport: tr} + tr.RegisterProtocol("foo", fooProto{}) + res, err := c.Get("foo://bar.com/path") + if err != nil { + t.Fatal(err) + } + bodyb, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + body := string(bodyb) + if e := "You wanted foo://bar.com/path"; body != e { + t.Errorf("got response %q, want %q", body, e) + } +} + // rgz is a gzip quine that uncompresses to itself. var rgz = []byte{ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, diff --git a/src/pkg/http/url.go b/src/pkg/http/url.go index 0fc0cb2d7..05b1662d3 100644 --- a/src/pkg/http/url.go +++ b/src/pkg/http/url.go @@ -449,7 +449,7 @@ func ParseURLReference(rawurlref string) (url *URL, err os.Error) { // // There are redundant fields stored in the URL structure: // the String method consults Scheme, Path, Host, RawUserinfo, -// RawQuery, and Fragment, but not Raw, RawPath or Authority. +// RawQuery, and Fragment, but not Raw, RawPath or RawAuthority. func (url *URL) String() string { result := "" if url.Scheme != "" { @@ -486,10 +486,14 @@ func (url *URL) String() string { return result } -// EncodeQuery encodes the query represented as a multimap. -func EncodeQuery(m map[string][]string) string { - parts := make([]string, 0, len(m)) // will be large enough for most uses - for k, vs := range m { +// Encode encodes the values into ``URL encoded'' form. +// e.g. "foo=bar&bar=baz" +func (v Values) Encode() string { + if v == nil { + return "" + } + parts := make([]string, 0, len(v)) // will be large enough for most uses + for k, vs := range v { prefix := URLEscape(k) + "=" for _, v := range vs { parts = append(parts, prefix+URLEscape(v)) @@ -593,3 +597,9 @@ func (base *URL) ResolveReference(ref *URL) *URL { url.Raw = url.String() return url } + +// Query parses RawQuery and returns the corresponding values. +func (u *URL) Query() Values { + v, _ := ParseQuery(u.RawQuery) + return v +} diff --git a/src/pkg/http/url_test.go b/src/pkg/http/url_test.go index d8863f3d3..eaec5872a 100644 --- a/src/pkg/http/url_test.go +++ b/src/pkg/http/url_test.go @@ -538,23 +538,21 @@ func TestUnescapeUserinfo(t *testing.T) { } } -type qMap map[string][]string - type EncodeQueryTest struct { - m qMap + m Values expected string expected1 string } var encodeQueryTests = []EncodeQueryTest{ {nil, "", ""}, - {qMap{"q": {"puppies"}, "oe": {"utf8"}}, "q=puppies&oe=utf8", "oe=utf8&q=puppies"}, - {qMap{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7", "q=dogs&q=%26&q=7"}, + {Values{"q": {"puppies"}, "oe": {"utf8"}}, "q=puppies&oe=utf8", "oe=utf8&q=puppies"}, + {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7", "q=dogs&q=%26&q=7"}, } func TestEncodeQuery(t *testing.T) { for _, tt := range encodeQueryTests { - if q := EncodeQuery(tt.m); q != tt.expected && q != tt.expected1 { + if q := tt.m.Encode(); q != tt.expected && q != tt.expected1 { t.Errorf(`EncodeQuery(%+v) = %q, want %q`, tt.m, q, tt.expected) } } @@ -673,3 +671,28 @@ func TestResolveReference(t *testing.T) { } } + +func TestQueryValues(t *testing.T) { + u, _ := ParseURL("http://x.com?foo=bar&bar=1&bar=2") + v := u.Query() + if len(v) != 2 { + t.Errorf("got %d keys in Query values, want 2", len(v)) + } + if g, e := v.Get("foo"), "bar"; g != e { + t.Errorf("Get(foo) = %q, want %q", g, e) + } + // Case sensitive: + if g, e := v.Get("Foo"), ""; g != e { + t.Errorf("Get(Foo) = %q, want %q", g, e) + } + if g, e := v.Get("bar"), "1"; g != e { + t.Errorf("Get(bar) = %q, want %q", g, e) + } + if g, e := v.Get("baz"), ""; g != e { + t.Errorf("Get(baz) = %q, want %q", g, e) + } + v.Del("bar") + if g, e := v.Get("bar"), ""; g != e { + t.Errorf("second Get(bar) = %q, want %q", g, e) + } +} |