diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-04-20 15:44:41 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-04-20 15:44:41 +0200 |
commit | 50104cc32a498f7517a51c8dc93106c51c7a54b4 (patch) | |
tree | 47af80be259cc7c45d0eaec7d42e61fa38c8e4fb /src/pkg/http | |
parent | c072558b90f1bbedc2022b0f30c8b1ac4712538e (diff) | |
download | golang-upstream/2011.03.07.1.tar.gz |
Imported Upstream version 2011.03.07.1upstream/2011.03.07.1
Diffstat (limited to 'src/pkg/http')
29 files changed, 2019 insertions, 594 deletions
diff --git a/src/pkg/http/Makefile b/src/pkg/http/Makefile index 7e4f80c28..389b04222 100644 --- a/src/pkg/http/Makefile +++ b/src/pkg/http/Makefile @@ -8,8 +8,10 @@ TARG=http GOFILES=\ chunked.go\ client.go\ + cookie.go\ dump.go\ fs.go\ + header.go\ lex.go\ persist.go\ request.go\ @@ -17,6 +19,7 @@ GOFILES=\ server.go\ status.go\ transfer.go\ + transport.go\ url.go\ include ../../Make.pkg diff --git a/src/pkg/http/cgi/Makefile b/src/pkg/http/cgi/Makefile new file mode 100644 index 000000000..02f6cfc9e --- /dev/null +++ b/src/pkg/http/cgi/Makefile @@ -0,0 +1,11 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=http/cgi +GOFILES=\ + cgi.go\ + +include ../../../Make.pkg diff --git a/src/pkg/http/cgi/cgi.go b/src/pkg/http/cgi/cgi.go new file mode 100644 index 000000000..dba59efa2 --- /dev/null +++ b/src/pkg/http/cgi/cgi.go @@ -0,0 +1,201 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cgi implements CGI (Common Gateway Interface) as specified +// in RFC 3875. +// +// Note that using CGI means starting a new process to handle each +// request, which is typically less efficient than using a +// long-running server. This package is intended primarily for +// compatibility with existing systems. +package cgi + +import ( + "encoding/line" + "exec" + "fmt" + "http" + "io" + "log" + "os" + "path" + "regexp" + "strconv" + "strings" +) + +var trailingPort = regexp.MustCompile(`:([0-9]+)$`) + +// Handler runs an executable in a subprocess with a CGI environment. +type Handler struct { + Path string // path to the CGI executable + Root string // root URI prefix of handler or empty for "/" + Env []string // extra environment variables to set, if any + Logger *log.Logger // optional log for errors or nil to use log.Print +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + root := h.Root + if root == "" { + root = "/" + } + + if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("Chunked request bodies are not supported by CGI.")) + return + } + + pathInfo := req.URL.Path + if root != "/" && strings.HasPrefix(pathInfo, root) { + pathInfo = pathInfo[len(root):] + } + + port := "80" + if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 { + port = matches[1] + } + + env := []string{ + "SERVER_SOFTWARE=go", + "SERVER_NAME=" + req.Host, + "HTTP_HOST=" + req.Host, + "GATEWAY_INTERFACE=CGI/1.1", + "REQUEST_METHOD=" + req.Method, + "QUERY_STRING=" + req.URL.RawQuery, + "REQUEST_URI=" + req.URL.RawPath, + "PATH_INFO=" + pathInfo, + "SCRIPT_NAME=" + root, + "SCRIPT_FILENAME=" + h.Path, + "REMOTE_ADDR=" + rw.RemoteAddr(), + "REMOTE_HOST=" + rw.RemoteAddr(), + "SERVER_PORT=" + port, + } + + for k, _ := range req.Header { + k = strings.Map(upperCaseAndUnderscore, k) + env = append(env, "HTTP_"+k+"="+req.Header.Get(k)) + } + + if req.ContentLength > 0 { + env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength)) + } + if ctype := req.Header.Get("Content-Type"); ctype != "" { + env = append(env, "CONTENT_TYPE="+ctype) + } + + if h.Env != nil { + env = append(env, h.Env...) + } + + // TODO: use filepath instead of path when available + cwd, pathBase := path.Split(h.Path) + if cwd == "" { + cwd = "." + } + + cmd, err := exec.Run( + pathBase, + []string{h.Path}, + env, + cwd, + exec.Pipe, // stdin + exec.Pipe, // stdout + exec.PassThrough, // stderr (for now) + ) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("CGI error: %v", err) + return + } + defer func() { + cmd.Stdin.Close() + cmd.Stdout.Close() + cmd.Wait(0) // no zombies + }() + + if req.ContentLength != 0 { + go io.Copy(cmd.Stdin, req.Body) + } + + linebody := line.NewReader(cmd.Stdout, 1024) + headers := make(map[string]string) + statusCode := http.StatusOK + for { + line, isPrefix, err := linebody.ReadLine() + if isPrefix { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("CGI: long header line from subprocess.") + return + } + if err == os.EOF { + break + } + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("CGI: error reading headers: %v", err) + return + } + if len(line) == 0 { + break + } + parts := strings.Split(string(line), ":", 2) + if len(parts) < 2 { + h.printf("CGI: bogus header line: %s", string(line)) + continue + } + header, val := parts[0], parts[1] + header = strings.TrimSpace(header) + val = strings.TrimSpace(val) + switch { + case header == "Status": + if len(val) < 3 { + h.printf("CGI: bogus status (short): %q", val) + return + } + code, err := strconv.Atoi(val[0:3]) + if err != nil { + h.printf("CGI: bogus status: %q", val) + h.printf("CGI: line was %q", line) + return + } + statusCode = code + default: + headers[header] = val + } + } + for h, v := range headers { + rw.SetHeader(h, v) + } + rw.WriteHeader(statusCode) + + _, err = io.Copy(rw, linebody) + if err != nil { + h.printf("CGI: copy error: %v", err) + } +} + +func (h *Handler) printf(format string, v ...interface{}) { + if h.Logger != nil { + h.Logger.Printf(format, v...) + } else { + log.Printf(format, v...) + } +} + +func upperCaseAndUnderscore(rune int) int { + switch { + case rune >= 'a' && rune <= 'z': + return rune - ('a' - 'A') + case rune == '-': + return '_' + case rune == '=': + // Maybe not part of the CGI 'spec' but would mess up + // the environment in any case, as Go represents the + // environment as a slice of "key=value" strings. + return '_' + } + // TODO: other transformations in spec or practice? + return rune +} diff --git a/src/pkg/http/cgi/cgi_test.go b/src/pkg/http/cgi/cgi_test.go new file mode 100644 index 000000000..daf9a2cb3 --- /dev/null +++ b/src/pkg/http/cgi/cgi_test.go @@ -0,0 +1,247 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for package cgi + +package cgi + +import ( + "bufio" + "exec" + "fmt" + "http" + "http/httptest" + "os" + "strings" + "testing" +) + +var cgiScriptWorks = canRun("./testdata/test.cgi") + +func canRun(s string) bool { + c, err := exec.Run(s, []string{s}, nil, ".", exec.DevNull, exec.DevNull, exec.DevNull) + if err != nil { + return false + } + w, err := c.Wait(0) + if err != nil { + return false + } + return w.Exited() && w.ExitStatus() == 0 +} + +func newRequest(httpreq string) *http.Request { + buf := bufio.NewReader(strings.NewReader(httpreq)) + req, err := http.ReadRequest(buf) + if err != nil { + panic("cgi: bogus http request in test: " + httpreq) + } + return req +} + +func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder { + rw := httptest.NewRecorder() + req := newRequest(httpreq) + h.ServeHTTP(rw, req) + + // Make a map to hold the test map that the CGI returns. + m := make(map[string]string) +readlines: + for { + line, err := rw.Body.ReadString('\n') + switch { + case err == os.EOF: + break readlines + case err != nil: + t.Fatalf("unexpected error reading from CGI: %v", err) + } + line = strings.TrimRight(line, "\r\n") + split := strings.Split(line, "=", 2) + if len(split) != 2 { + t.Fatalf("Unexpected %d parts from invalid line: %q", len(split), line) + } + m[split[0]] = split[1] + } + + for key, expected := range expectedMap { + if got := m[key]; got != expected { + t.Errorf("for key %q got %q; expected %q", key, got, expected) + } + } + return rw +} + +func skipTest(t *testing.T) bool { + if !cgiScriptWorks { + // No Perl on Windows, needed by test.cgi + // TODO: make the child process be Go, not Perl. + t.Logf("Skipping test: test.cgi failed.") + return true + } + return false +} + + +func TestCGIBasicGet(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html", replay.Header.Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header.Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +func TestCGIBasicGetAbsPath(t *testing.T) { + if skipTest(t) { + return + } + pwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd error: %v", err) + } + h := &Handler{ + Path: pwd + "/testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfo(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "param-a": "b", + "env-PATH_INFO": "/extrapath", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/test.cgi/extrapath?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfoDirRoot(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/myscript/", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/myscript/", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfoNoRoot(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "/bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/", + } + runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestCGIBasicPost(t *testing.T) { + if skipTest(t) { + return + } + postReq := `POST /test.cgi?a=b HTTP/1.0 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Content-Length: 15 + +postfoo=postbar` + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-postfoo": "postbar", + "env-REQUEST_METHOD": "POST", + "env-CONTENT_LENGTH": "15", + "env-REQUEST_URI": "/test.cgi?a=b", + } + runCgiTest(t, h, postReq, expectedMap) +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +// The CGI spec doesn't allow chunked requests. +func TestCGIPostChunked(t *testing.T) { + if skipTest(t) { + return + } + postReq := `POST /test.cgi?a=b HTTP/1.1 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Transfer-Encoding: chunked + +` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("") + + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{} + resp := runCgiTest(t, h, postReq, expectedMap) + if got, expected := resp.Code, http.StatusBadRequest; got != expected { + t.Fatalf("Expected %v response code from chunked request body; got %d", + expected, got) + } +} diff --git a/src/pkg/http/cgi/testdata/test.cgi b/src/pkg/http/cgi/testdata/test.cgi new file mode 100755 index 000000000..b931b04c5 --- /dev/null +++ b/src/pkg/http/cgi/testdata/test.cgi @@ -0,0 +1,34 @@ +#!/usr/bin/perl +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +# +# Test script run as a child process under cgi_test.go + +use strict; +use CGI; + +my $q = CGI->new; +my $params = $q->Vars; + +my $NL = "\r\n"; +$NL = "\n" if 1 || $params->{mode} eq "NL"; + +my $p = sub { + print "$_[0]$NL"; +}; + +# With carriage returns +$p->("Content-Type: text/html"); +$p->("X-Test-Header: X-Test-Value"); +$p->(""); + +print "test=Hello CGI\n"; + +foreach my $k (sort keys %$params) { + print "param-$k=$params->{$k}\n"; +} + +foreach my $k (sort keys %ENV) { + print "env-$k=$ENV{$k}\n"; +} diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index 022f4f124..c24eea581 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -7,18 +7,41 @@ package http import ( - "bufio" "bytes" - "crypto/tls" "encoding/base64" "fmt" "io" - "net" "os" "strconv" "strings" ) +// A Client is an HTTP client. Its zero value (DefaultClient) is a usable client +// that uses DefaultTransport. +// Client is not yet very configurable. +type Client struct { + Transport Transport // if nil, DefaultTransport is used +} + +// DefaultClient is the default Client and is used by Get, Head, and Post. +var DefaultClient = &Client{} + +// Transport is an interface representing the ability to execute a +// single HTTP transaction, obtaining the Response for a given Request. +type Transport interface { + // Do executes a single HTTP transaction, returning the Response for the + // request req. Do should not attempt to interpret the response. + // In particular, Do must return err == nil if it obtained a response, + // regardless of the response's HTTP status code. A non-nil err should + // be reserved for failure to obtain a response. Similarly, Do should + // not attempt to handle higher-level protocol details such as redirects, + // authentication, or cookies. + // + // Transports may modify the request. The request Headers field is + // guaranteed to be initalized. + Do(req *Request) (resp *Response, err os.Error) +} + // Given a string of the form "host", "host:port", or "[ipv6::address]:port", // return true if the string includes a port. func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } @@ -31,67 +54,83 @@ type readClose struct { io.Closer } -// Send issues an HTTP request. Caller should close resp.Body when done reading it. +// matchNoProxy returns true if requests to addr should not use a proxy, +// according to the NO_PROXY or no_proxy environment variable. +func matchNoProxy(addr string) bool { + if len(addr) == 0 { + return false + } + no_proxy := os.Getenv("NO_PROXY") + if len(no_proxy) == 0 { + no_proxy = os.Getenv("no_proxy") + } + if no_proxy == "*" { + return true + } + + addr = strings.ToLower(strings.TrimSpace(addr)) + if hasPort(addr) { + addr = addr[:strings.LastIndex(addr, ":")] + } + + for _, p := range strings.Split(no_proxy, ",", -1) { + p = strings.ToLower(strings.TrimSpace(p)) + if len(p) == 0 { + continue + } + if hasPort(p) { + p = p[:strings.LastIndex(p, ":")] + } + if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) { + return true + } + } + return false +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (e.g. redirects, cookies, auth) as configured on the client. +// +// Callers should close resp.Body when done reading from it. +// +// Generally Get, Post, or PostForm will be used instead of Do. +func (c *Client) Do(req *Request) (resp *Response, err os.Error) { + return send(req, c.Transport) +} + + +// send issues an HTTP request. Caller should close resp.Body when done reading from it. // // TODO: support persistent connections (multiple requests on a single connection). // send() method is nonpublic because, when we refactor the code for persistent // connections, it may no longer make sense to have a method with this signature. -func send(req *Request) (resp *Response, err os.Error) { - if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} +func send(req *Request, t Transport) (resp *Response, err os.Error) { + if t == nil { + t = DefaultTransport + if t == nil { + err = os.NewError("no http.Client.Transport or http.DefaultTransport") + return + } } - addr := req.URL.Host - if !hasPort(addr) { - addr += ":" + req.URL.Scheme + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + req.Header = make(Header) } + info := req.URL.RawUserinfo if len(info) > 0 { enc := base64.URLEncoding encoded := make([]byte, enc.EncodedLen(len(info))) enc.Encode(encoded, []byte(info)) if req.Header == nil { - req.Header = make(map[string]string) + req.Header = make(Header) } - req.Header["Authorization"] = "Basic " + string(encoded) - } - - var conn io.ReadWriteCloser - if req.URL.Scheme == "http" { - conn, err = net.Dial("tcp", "", addr) - if err != nil { - return nil, err - } - } else { // https - conn, err = tls.Dial("tcp", "", addr, nil) - if err != nil { - return nil, err - } - h := req.URL.Host - if hasPort(h) { - h = h[0:strings.LastIndex(h, ":")] - } - if err := conn.(*tls.Conn).VerifyHostname(h); err != nil { - return nil, err - } - } - - err = req.Write(conn) - if err != nil { - conn.Close() - return nil, err + req.Header.Set("Authorization", "Basic "+string(encoded)) } - - reader := bufio.NewReader(conn) - resp, err = ReadResponse(reader, req.Method) - if err != nil { - conn.Close() - return nil, err - } - - resp.Body = readClose{resp.Body, conn} - - return + return t.Do(req) } // True if the specified HTTP status code is one for which the Get utility should @@ -115,12 +154,32 @@ func shouldRedirect(statusCode int) bool { // finalURL is the URL from which the response was fetched -- identical to the // input URL unless redirects were followed. // -// Caller should close r.Body when done reading it. +// Caller should close r.Body when done reading from it. +// +// Get is a convenience wrapper around DefaultClient.Get. func Get(url string) (r *Response, finalURL string, err os.Error) { + return DefaultClient.Get(url) +} + +// Get issues a GET to the specified URL. If the response is one of the following +// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// finalURL is the URL from which the response was fetched -- identical to the +// input URL unless redirects were followed. +// +// Caller should close r.Body when done reading from it. +func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. // TODO: set referrer header on redirects. var base *URL + // TODO: remove this hard-coded 10 and use the Client's policy + // (ClientConfig) instead. for redirect := 0; ; redirect++ { if redirect >= 10 { err = os.ErrorString("stopped after 10 redirects") @@ -128,6 +187,9 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { } var req Request + req.Method = "GET" + req.ProtoMajor = 1 + req.ProtoMinor = 1 if base == nil { req.URL, err = ParseURL(url) } else { @@ -137,12 +199,12 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { break } url = req.URL.String() - if r, err = send(&req); err != nil { + if r, err = send(&req, c.Transport); err != nil { break } if shouldRedirect(r.StatusCode) { r.Body.Close() - if url = r.GetHeader("Location"); url == "" { + if url = r.Header.Get("Location"); url == "" { err = os.ErrorString(fmt.Sprintf("%d response missing Location header", r.StatusCode)) break } @@ -159,16 +221,25 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading it. +// Caller should close r.Body when done reading from it. +// +// Post is a wrapper around DefaultClient.Post func Post(url string, bodyType string, body io.Reader) (r *Response, err os.Error) { + return DefaultClient.Post(url, bodyType, body) +} + +// Post issues a POST to the specified URL. +// +// Caller should close r.Body when done reading from it. +func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err os.Error) { var req Request req.Method = "POST" req.ProtoMajor = 1 req.ProtoMinor = 1 req.Close = true req.Body = nopCloser{body} - req.Header = map[string]string{ - "Content-Type": bodyType, + req.Header = Header{ + "Content-Type": {bodyType}, } req.TransferEncoding = []string{"chunked"} @@ -177,14 +248,24 @@ func Post(url string, bodyType string, body io.Reader) (r *Response, err os.Erro return nil, err } - return send(&req) + return send(&req, c.Transport) } // PostForm issues a POST to the specified URL, // with data's keys and values urlencoded as the request body. // -// Caller should close r.Body when done reading it. +// Caller should close r.Body when done reading from it. +// +// PostForm is a wrapper around DefaultClient.PostForm func PostForm(url string, data map[string]string) (r *Response, err os.Error) { + return DefaultClient.PostForm(url, data) +} + +// PostForm issues a POST to the specified URL, +// with data's keys and values urlencoded as the request body. +// +// Caller should close r.Body when done reading from it. +func (c *Client) PostForm(url string, data map[string]string) (r *Response, err os.Error) { var req Request req.Method = "POST" req.ProtoMajor = 1 @@ -192,9 +273,9 @@ func PostForm(url string, data map[string]string) (r *Response, err os.Error) { req.Close = true body := urlencode(data) req.Body = nopCloser{body} - req.Header = map[string]string{ - "Content-Type": "application/x-www-form-urlencoded", - "Content-Length": strconv.Itoa(body.Len()), + req.Header = Header{ + "Content-Type": {"application/x-www-form-urlencoded"}, + "Content-Length": {strconv.Itoa(body.Len())}, } req.ContentLength = int64(body.Len()) @@ -203,7 +284,7 @@ func PostForm(url string, data map[string]string) (r *Response, err os.Error) { return nil, err } - return send(&req) + return send(&req, c.Transport) } // TODO: remove this function when PostForm takes a multimap. @@ -216,17 +297,20 @@ func urlencode(data map[string]string) (b *bytes.Buffer) { } // Head issues a HEAD to the specified URL. +// +// Head is a wrapper around DefaultClient.Head func Head(url string) (r *Response, err os.Error) { + return DefaultClient.Head(url) +} + +// Head issues a HEAD to the specified URL. +func (c *Client) Head(url string) (r *Response, err os.Error) { var req Request req.Method = "HEAD" if req.URL, err = ParseURL(url); err != nil { return } - url = req.URL.String() - if r, err = send(&req); err != nil { - return - } - return + return send(&req, c.Transport) } type nopCloser struct { diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index 013653a82..c89ecbce2 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -8,6 +8,7 @@ package http import ( "io/ioutil" + "os" "strings" "testing" ) @@ -38,3 +39,28 @@ func TestClientHead(t *testing.T) { t.Error("Last-Modified header not found.") } } + +type recordingTransport struct { + req *Request +} + +func (t *recordingTransport) Do(req *Request) (resp *Response, err os.Error) { + t.req = req + return nil, os.NewError("dummy impl") +} + +func TestGetRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + url := "http://dummy.faketld/" + client.Get(url) // Note: doesn't hit network + if tr.req.Method != "GET" { + t.Errorf("expected method %q; got %q", "GET", tr.req.Method) + } + if tr.req.URL.String() != url { + t.Errorf("expected URL %q; got %q", url, tr.req.URL.String()) + } + if tr.req.Header == nil { + t.Errorf("expected non-nil request Header") + } +} diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go new file mode 100644 index 000000000..ff75c47c9 --- /dev/null +++ b/src/pkg/http/cookie.go @@ -0,0 +1,336 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "fmt" + "io" + "os" + "sort" + "strconv" + "strings" + "time" +) + +// A note on Version=0 vs. Version=1 cookies +// +// The difference between Set-Cookie and Set-Cookie2 is hard to discern from the +// RFCs as it is not stated explicitly. There seem to be three standards +// lingering on the web: Netscape, RFC 2109 (aka Version=0) and RFC 2965 (aka +// Version=1). It seems that Netscape and RFC 2109 are the same thing, hereafter +// Version=0 cookies. +// +// In general, Set-Cookie2 is a superset of Set-Cookie. It has a few new +// attributes like HttpOnly and Secure. To be meticulous, if a server intends +// to use these, it needs to send a Set-Cookie2. However, it is most likely +// most modern browsers will not complain seeing an HttpOnly attribute in a +// Set-Cookie header. +// +// Both RFC 2109 and RFC 2965 use Cookie in the same way - two send cookie +// values from clients to servers - and the allowable attributes seem to be the +// same. +// +// The Cookie2 header is used for a different purpose. If a client suspects that +// the server speaks Version=1 (RFC 2965) then along with the Cookie header +// lines, you can also send: +// +// Cookie2: $Version="1" +// +// in order to suggest to the server that you understand Version=1 cookies. At +// which point the server may continue responding with Set-Cookie2 headers. +// When a client sends the (above) Cookie2 header line, it must be prepated to +// understand incoming Set-Cookie2. +// +// This implementation of cookies supports neither Set-Cookie2 nor Cookie2 +// headers. However, it parses Version=1 Cookies (along with Version=0) as well +// as Set-Cookie headers which utilize the full Set-Cookie2 syntax. + +// TODO(petar): Explicitly forbid parsing of Set-Cookie attributes +// starting with '$', which have been used to hack into broken +// servers using the eventual Request headers containing those +// invalid attributes that may overwrite intended $Version, $Path, +// etc. attributes. +// TODO(petar): Read 'Set-Cookie2' headers and prioritize them over equivalent +// 'Set-Cookie' headers. 'Set-Cookie2' headers are still extremely rare. + +// A Cookie represents an RFC 2965 HTTP cookie as sent in +// the Set-Cookie header of an HTTP response or the Cookie header +// of an HTTP request. +// The Set-Cookie2 and Cookie2 headers are unimplemented. +type Cookie struct { + Name string + Value string + Path string + Domain string + Comment string + Version int + Expires time.Time + RawExpires string + MaxAge int // Max age in seconds + Secure bool + HttpOnly bool + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// readSetCookies parses all "Set-Cookie" values from +// the header h, removes the successfully parsed values from the +// "Set-Cookie" key in h and returns the parsed Cookies. +func readSetCookies(h Header) []*Cookie { + cookies := []*Cookie{} + var unparsedLines []string + for _, line := range h["Set-Cookie"] { + parts := strings.Split(strings.TrimSpace(line), ";", -1) + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = strings.TrimSpace(parts[0]) + j := strings.Index(parts[0], "=") + if j < 0 { + unparsedLines = append(unparsedLines, line) + continue + } + name, value := parts[0][:j], parts[0][j+1:] + value, err := URLUnescape(value) + if err != nil { + unparsedLines = append(unparsedLines, line) + continue + } + c := &Cookie{ + Name: name, + Value: value, + MaxAge: -1, // Not specified + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = strings.TrimSpace(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val := parts[i], "" + if j := strings.Index(attr, "="); j >= 0 { + attr, val = attr[:j], attr[j+1:] + val, err = URLUnescape(val) + if err != nil { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + } + switch strings.ToLower(attr) { + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "comment": + c.Comment = val + continue + case "domain": + c.Domain = val + // TODO: Add domain parsing + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs < 0 { + break + } + c.MaxAge = secs + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + c.Expires = time.Time{} + break + } + c.Expires = *exptime + continue + case "path": + c.Path = val + // TODO: Add path parsing + continue + case "version": + c.Version, err = strconv.Atoi(val) + if err != nil { + c.Version = 0 + break + } + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + h["Set-Cookie"] = unparsedLines, unparsedLines != nil + return cookies +} + +// writeSetCookies writes the wire representation of the set-cookies +// to w. Each cookie is written on a separate "Set-Cookie: " line. +// This choice is made because HTTP parsers tend to have a limit on +// line-length, so it seems safer to place cookies on separate lines. +func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { + if kk == nil { + return nil + } + lines := make([]string, 0, len(kk)) + var b bytes.Buffer + for _, c := range kk { + b.Reset() + // TODO(petar): c.Value (below) should be unquoted if it is recognized as quoted + fmt.Fprintf(&b, "%s=%s", CanonicalHeaderKey(c.Name), c.Value) + if c.Version > 0 { + fmt.Fprintf(&b, "Version=%d; ", c.Version) + } + if len(c.Path) > 0 { + fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path)) + } + if len(c.Domain) > 0 { + fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain)) + } + if len(c.Expires.Zone) > 0 { + fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) + } + if c.MaxAge >= 0 { + fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) + } + if c.HttpOnly { + fmt.Fprintf(&b, "; HttpOnly") + } + if c.Secure { + fmt.Fprintf(&b, "; Secure") + } + if len(c.Comment) > 0 { + fmt.Fprintf(&b, "; Comment=%s", URLEscape(c.Comment)) + } + lines = append(lines, "Set-Cookie: "+b.String()+"\r\n") + } + sort.SortStrings(lines) + for _, l := range lines { + if _, err := io.WriteString(w, l); err != nil { + return err + } + } + return nil +} + +// readCookies parses all "Cookie" values from +// the header h, removes the successfully parsed values from the +// "Cookie" key in h and returns the parsed Cookies. +func readCookies(h Header) []*Cookie { + cookies := []*Cookie{} + lines, ok := h["Cookie"] + if !ok { + return cookies + } + unparsedLines := []string{} + for _, line := range lines { + parts := strings.Split(strings.TrimSpace(line), ";", -1) + if len(parts) == 1 && parts[0] == "" { + continue + } + // Per-line attributes + var lineCookies = make(map[string]string) + var version int + var path string + var domain string + var comment string + var httponly bool + for i := 0; i < len(parts); i++ { + parts[i] = strings.TrimSpace(parts[i]) + if len(parts[i]) == 0 { + continue + } + attr, val := parts[i], "" + var err os.Error + if j := strings.Index(attr, "="); j >= 0 { + attr, val = attr[:j], attr[j+1:] + val, err = URLUnescape(val) + if err != nil { + continue + } + } + switch strings.ToLower(attr) { + case "$httponly": + httponly = true + case "$version": + version, err = strconv.Atoi(val) + if err != nil { + version = 0 + continue + } + case "$domain": + domain = val + // TODO: Add domain parsing + case "$path": + path = val + // TODO: Add path parsing + case "$comment": + comment = val + default: + lineCookies[attr] = val + } + } + if len(lineCookies) == 0 { + unparsedLines = append(unparsedLines, line) + } + for n, v := range lineCookies { + cookies = append(cookies, &Cookie{ + Name: n, + Value: v, + Path: path, + Domain: domain, + Comment: comment, + Version: version, + HttpOnly: httponly, + MaxAge: -1, + Raw: line, + }) + } + } + h["Cookie"] = unparsedLines, len(unparsedLines) > 0 + return cookies +} + +// writeCookies writes the wire representation of the cookies +// to w. Each cookie is written on a separate "Cookie: " line. +// This choice is made because HTTP parsers tend to have a limit on +// line-length, so it seems safer to place cookies on separate lines. +func writeCookies(w io.Writer, kk []*Cookie) os.Error { + lines := make([]string, 0, len(kk)) + var b bytes.Buffer + for _, c := range kk { + b.Reset() + n := c.Name + if c.Version > 0 { + fmt.Fprintf(&b, "$Version=%d; ", c.Version) + } + // TODO(petar): c.Value (below) should be unquoted if it is recognized as quoted + fmt.Fprintf(&b, "%s=%s", CanonicalHeaderKey(n), c.Value) + if len(c.Path) > 0 { + fmt.Fprintf(&b, "; $Path=%s", URLEscape(c.Path)) + } + if len(c.Domain) > 0 { + fmt.Fprintf(&b, "; $Domain=%s", URLEscape(c.Domain)) + } + if c.HttpOnly { + fmt.Fprintf(&b, "; $HttpOnly") + } + if len(c.Comment) > 0 { + fmt.Fprintf(&b, "; $Comment=%s", URLEscape(c.Comment)) + } + lines = append(lines, "Cookie: "+b.String()+"\r\n") + } + sort.SortStrings(lines) + for _, l := range lines { + if _, err := io.WriteString(w, l); err != nil { + return err + } + } + return nil +} diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go new file mode 100644 index 000000000..363c841bb --- /dev/null +++ b/src/pkg/http/cookie_test.go @@ -0,0 +1,96 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "reflect" + "testing" +) + + +var writeSetCookiesTests = []struct { + Cookies []*Cookie + Raw string +}{ + { + []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1", MaxAge: -1}}, + "Set-Cookie: Cookie-1=v$1\r\n", + }, +} + +func TestWriteSetCookies(t *testing.T) { + for i, tt := range writeSetCookiesTests { + var w bytes.Buffer + writeSetCookies(&w, tt.Cookies) + seen := string(w.Bytes()) + if seen != tt.Raw { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, seen) + continue + } + } +} + +var writeCookiesTests = []struct { + Cookies []*Cookie + Raw string +}{ + { + []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1", MaxAge: -1}}, + "Cookie: Cookie-1=v$1\r\n", + }, +} + +func TestWriteCookies(t *testing.T) { + for i, tt := range writeCookiesTests { + var w bytes.Buffer + writeCookies(&w, tt.Cookies) + seen := string(w.Bytes()) + if seen != tt.Raw { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, seen) + continue + } + } +} + +var readSetCookiesTests = []struct { + Header Header + Cookies []*Cookie +}{ + { + Header{"Set-Cookie": {"Cookie-1=v$1"}}, + []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", MaxAge: -1, Raw: "Cookie-1=v$1"}}, + }, +} + +func TestReadSetCookies(t *testing.T) { + for i, tt := range readSetCookiesTests { + c := readSetCookies(tt.Header) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readSetCookies: have\n%#v\nwant\n%#v\n", i, c, tt.Cookies) + continue + } + } +} + +var readCookiesTests = []struct { + Header Header + Cookies []*Cookie +}{ + { + Header{"Cookie": {"Cookie-1=v$1"}}, + []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", MaxAge: -1, Raw: "Cookie-1=v$1"}}, + }, +} + +func TestReadCookies(t *testing.T) { + for i, tt := range readCookiesTests { + c := readCookies(tt.Header) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readCookies: have\n%#v\nwant\n%#v\n", i, c, tt.Cookies) + continue + } + } +} diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go index bbfa58d26..a4cd7072e 100644 --- a/src/pkg/http/fs.go +++ b/src/pkg/http/fs.go @@ -11,7 +11,7 @@ import ( "io" "mime" "os" - "path" + "path/filepath" "strconv" "strings" "time" @@ -104,7 +104,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { } } - if t, _ := time.Parse(TimeFormat, r.Header["If-Modified-Since"]); t != nil && d.Mtime_ns/1e9 <= t.Seconds() { + if t, _ := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); t != nil && d.Mtime_ns/1e9 <= t.Seconds() { w.WriteHeader(StatusNotModified) return } @@ -112,7 +112,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { // use contents of index.html for directory, if present if d.IsDirectory() { - index := name + indexPage + index := name + filepath.FromSlash(indexPage) ff, err := os.Open(index, os.O_RDONLY, 0) if err == nil { defer ff.Close() @@ -135,7 +135,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { code := StatusOK // use extension to find content type. - ext := path.Ext(name) + ext := filepath.Ext(name) if ctype := mime.TypeByExtension(ext); ctype != "" { w.SetHeader("Content-Type", ctype) } else { @@ -153,7 +153,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { // handle Content-Range header. // TODO(adg): handle multiple ranges - ranges, err := parseRange(r.Header["Range"], size) + ranges, err := parseRange(r.Header.Get("Range"), size) if err != nil || len(ranges) > 1 { Error(w, err.String(), StatusRequestedRangeNotSatisfiable) return @@ -202,7 +202,7 @@ func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { return } path = path[len(f.prefix):] - serveFile(w, r, f.root+"/"+path, true) + serveFile(w, r, filepath.Join(f.root, filepath.FromSlash(path)), true) } // httpRange specifies the byte range to be sent to the client. diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go index 0a5636b88..a89c76d0b 100644 --- a/src/pkg/http/fs_test.go +++ b/src/pkg/http/fs_test.go @@ -2,89 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http +package http_test import ( "fmt" + . "http" + "http/httptest" "io/ioutil" - "net" "os" - "sync" "testing" ) -var ParseRangeTests = []struct { - s string - length int64 - r []httpRange -}{ - {"", 0, nil}, - {"foo", 0, nil}, - {"bytes=", 0, nil}, - {"bytes=5-4", 10, nil}, - {"bytes=0-2,5-4", 10, nil}, - {"bytes=0-9", 10, []httpRange{{0, 10}}}, - {"bytes=0-", 10, []httpRange{{0, 10}}}, - {"bytes=5-", 10, []httpRange{{5, 5}}}, - {"bytes=0-20", 10, []httpRange{{0, 10}}}, - {"bytes=15-,0-5", 10, nil}, - {"bytes=-5", 10, []httpRange{{5, 5}}}, - {"bytes=-15", 10, []httpRange{{0, 10}}}, - {"bytes=0-499", 10000, []httpRange{{0, 500}}}, - {"bytes=500-999", 10000, []httpRange{{500, 500}}}, - {"bytes=-500", 10000, []httpRange{{9500, 500}}}, - {"bytes=9500-", 10000, []httpRange{{9500, 500}}}, - {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, - {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, - {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, -} - -func TestParseRange(t *testing.T) { - for _, test := range ParseRangeTests { - r := test.r - ranges, err := parseRange(test.s, test.length) - if err != nil && r != nil { - t.Errorf("parseRange(%q) returned error %q", test.s, err) - } - if len(ranges) != len(r) { - t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r)) - continue - } - for i := range r { - if ranges[i].start != r[i].start { - t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start) - } - if ranges[i].length != r[i].length { - t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length) - } - } - } -} - const ( testFile = "testdata/file" testFileLength = 11 ) -var ( - serverOnce sync.Once - serverAddr string -) - -func startServer(t *testing.T) { - serverOnce.Do(func() { - HandleFunc("/ServeFile", func(w ResponseWriter, r *Request) { - ServeFile(w, r, "testdata/file") - }) - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("listen:", err) - } - serverAddr = l.Addr().String() - go Serve(l, nil) - }) -} - var ServeFileRangeTests = []struct { start, end int r string @@ -99,7 +32,11 @@ var ServeFileRangeTests = []struct { } func TestServeFile(t *testing.T) { - startServer(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + var err os.Error file, err := ioutil.ReadFile(testFile) @@ -109,8 +46,8 @@ func TestServeFile(t *testing.T) { // set up the Request (re-used for all tests) var req Request - req.Header = make(map[string]string) - if req.URL, err = ParseURL("http://" + serverAddr + "/ServeFile"); err != nil { + req.Header = make(Header) + if req.URL, err = ParseURL(ts.URL); err != nil { t.Fatal("ParseURL:", err) } req.Method = "GET" @@ -123,9 +60,9 @@ func TestServeFile(t *testing.T) { // Range tests for _, rt := range ServeFileRangeTests { - req.Header["Range"] = "bytes=" + rt.r + req.Header.Set("Range", "bytes="+rt.r) if rt.r == "" { - req.Header["Range"] = "" + req.Header["Range"] = nil } r, body := getBody(t, req) if r.StatusCode != rt.code { @@ -138,8 +75,9 @@ func TestServeFile(t *testing.T) { if rt.r == "" { h = "" } - if r.Header["Content-Range"] != h { - t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, r.Header["Content-Range"], h) + cr := r.Header.Get("Content-Range") + if cr != h { + t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h) } if !equal(body, file[rt.start:rt.end]) { t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end]) @@ -148,7 +86,7 @@ func TestServeFile(t *testing.T) { } func getBody(t *testing.T, req Request) (*Response, []byte) { - r, err := send(&req) + r, err := DefaultClient.Do(&req) if err != nil { t.Fatal(req.URL.String(), "send:", err) } diff --git a/src/pkg/http/header.go b/src/pkg/http/header.go new file mode 100644 index 000000000..95b0f3db6 --- /dev/null +++ b/src/pkg/http/header.go @@ -0,0 +1,43 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import "net/textproto" + +// A Header represents the key-value pairs in an HTTP header. +type Header map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns "". +// Get is a convenience method. For more complex queries, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Del deletes the values associated with key. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } diff --git a/src/pkg/http/httptest/Makefile b/src/pkg/http/httptest/Makefile new file mode 100644 index 000000000..eb35d8aec --- /dev/null +++ b/src/pkg/http/httptest/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=http/httptest +GOFILES=\ + recorder.go\ + server.go\ + +include ../../../Make.pkg diff --git a/src/pkg/http/httptest/recorder.go b/src/pkg/http/httptest/recorder.go new file mode 100644 index 000000000..ec7bde8aa --- /dev/null +++ b/src/pkg/http/httptest/recorder.go @@ -0,0 +1,79 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The httptest package provides utilities for HTTP testing. +package httptest + +import ( + "bytes" + "http" + "os" +) + +// ResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + Code int // the HTTP response code from WriteHeader + Header http.Header // if non-nil, the headers to populate + Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to + Flushed bool + + FakeRemoteAddr string // the fake RemoteAddr to return, or "" for DefaultRemoteAddr + FakeUsingTLS bool // whether to return true from the UsingTLS method +} + +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + Header: http.Header(make(map[string][]string)), + Body: new(bytes.Buffer), + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +const DefaultRemoteAddr = "1.2.3.4" + +// RemoteAddr returns the value of rw.FakeRemoteAddr, if set, else +// returns DefaultRemoteAddr. +func (rw *ResponseRecorder) RemoteAddr() string { + if rw.FakeRemoteAddr != "" { + return rw.FakeRemoteAddr + } + return DefaultRemoteAddr +} + +// UsingTLS returns the fake value in rw.FakeUsingTLS +func (rw *ResponseRecorder) UsingTLS() bool { + return rw.FakeUsingTLS +} + +// SetHeader populates rw.Header, if non-nil. +func (rw *ResponseRecorder) SetHeader(k, v string) { + if rw.Header != nil { + if v == "" { + rw.Header.Del(k) + } else { + rw.Header.Set(k, v) + } + } +} + +// Write always succeeds and writes to rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, os.Error) { + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteHeader sets rw.Code. +func (rw *ResponseRecorder) WriteHeader(code int) { + rw.Code = code +} + +// Flush sets rw.Flushed to true. +func (rw *ResponseRecorder) Flush() { + rw.Flushed = true +} diff --git a/src/pkg/http/httptest/server.go b/src/pkg/http/httptest/server.go new file mode 100644 index 000000000..86c9eb435 --- /dev/null +++ b/src/pkg/http/httptest/server.go @@ -0,0 +1,42 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Implementation of Server + +package httptest + +import ( + "fmt" + "http" + "net" +) + +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener +} + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := new(Server) + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + ts.Listener = l + ts.URL = "http://" + l.Addr().String() + server := &http.Server{Handler: handler} + go server.Serve(l) + return ts +} + +// Close shuts down the server. +func (s *Server) Close() { + s.Listener.Close() +} diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go index 000a4200e..53efd7c8c 100644 --- a/src/pkg/http/persist.go +++ b/src/pkg/http/persist.go @@ -25,15 +25,15 @@ var ( // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. type ServerConn struct { + lk sync.Mutex // read-write protects the following fields c net.Conn r *bufio.Reader - clsd bool // indicates a graceful close re, we os.Error // read/write errors lastbody io.ReadCloser nread, nwritten int - pipe textproto.Pipeline pipereq map[*Request]uint - lk sync.Mutex // protected read/write to re,we + + pipe textproto.Pipeline } // NewServerConn returns a new ServerConn reading and writing c. If r is not @@ -90,15 +90,21 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) { defer sc.lk.Unlock() return nil, sc.re } + if sc.r == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return nil, os.EBADF + } + r := sc.r + lastbody := sc.lastbody + sc.lastbody = nil sc.lk.Unlock() // Make sure body is fully consumed, even if user does not call body.Close - if sc.lastbody != nil { + if lastbody != nil { // body.Close is assumed to be idempotent and multiple calls to // it should return the error that its first invokation // returned. - err = sc.lastbody.Close() - sc.lastbody = nil + err = lastbody.Close() if err != nil { sc.lk.Lock() defer sc.lk.Unlock() @@ -107,10 +113,10 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) { } } - req, err = ReadRequest(sc.r) + req, err = ReadRequest(r) + sc.lk.Lock() + defer sc.lk.Unlock() if err != nil { - sc.lk.Lock() - defer sc.lk.Unlock() if err == io.ErrUnexpectedEOF { // A close from the opposing client is treated as a // graceful close, even if there was some unparse-able @@ -119,18 +125,16 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) { return nil, sc.re } else { sc.re = err - return + return req, err } } sc.lastbody = req.Body sc.nread++ if req.Close { - sc.lk.Lock() - defer sc.lk.Unlock() sc.re = ErrPersistEOF return req, sc.re } - return + return req, err } // Pending returns the number of unanswered requests @@ -165,24 +169,27 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { defer sc.lk.Unlock() return sc.we } - sc.lk.Unlock() + if sc.c == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return os.EBADF + } + c := sc.c if sc.nread <= sc.nwritten { + defer sc.lk.Unlock() return os.NewError("persist server pipe count") } - if resp.Close { // After signaling a keep-alive close, any pipelined unread // requests will be lost. It is up to the user to drain them // before signaling. - sc.lk.Lock() sc.re = ErrPersistEOF - sc.lk.Unlock() } + sc.lk.Unlock() - err := resp.Write(sc.c) + err := resp.Write(c) + sc.lk.Lock() + defer sc.lk.Unlock() if err != nil { - sc.lk.Lock() - defer sc.lk.Unlock() sc.we = err return err } @@ -196,14 +203,15 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { // responsible for closing the underlying connection. One must call Close to // regain control of that connection and deal with it as desired. type ClientConn struct { + lk sync.Mutex // read-write protects the following fields c net.Conn r *bufio.Reader re, we os.Error // read/write errors lastbody io.ReadCloser nread, nwritten int - pipe textproto.Pipeline pipereq map[*Request]uint - lk sync.Mutex // protects read/write to re,we,pipereq,etc. + + pipe textproto.Pipeline } // NewClientConn returns a new ClientConn reading and writing c. If r is not @@ -221,11 +229,11 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { // logic. The user should not call Close while Read or Write is in progress. func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { cc.lk.Lock() + defer cc.lk.Unlock() c = cc.c r = cc.r cc.c = nil cc.r = nil - cc.lk.Unlock() return } @@ -261,20 +269,22 @@ func (cc *ClientConn) Write(req *Request) (err os.Error) { defer cc.lk.Unlock() return cc.we } - cc.lk.Unlock() - + if cc.c == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return os.EBADF + } + c := cc.c if req.Close { // We write the EOF to the write-side error, because there // still might be some pipelined reads - cc.lk.Lock() cc.we = ErrPersistEOF - cc.lk.Unlock() } + cc.lk.Unlock() - err = req.Write(cc.c) + err = req.Write(c) + cc.lk.Lock() + defer cc.lk.Unlock() if err != nil { - cc.lk.Lock() - defer cc.lk.Unlock() cc.we = err return err } @@ -316,15 +326,21 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { defer cc.lk.Unlock() return nil, cc.re } + if cc.r == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return nil, os.EBADF + } + r := cc.r + lastbody := cc.lastbody + cc.lastbody = nil cc.lk.Unlock() // Make sure body is fully consumed, even if user does not call body.Close - if cc.lastbody != nil { + if lastbody != nil { // body.Close is assumed to be idempotent and multiple calls to // it should return the error that its first invokation // returned. - err = cc.lastbody.Close() - cc.lastbody = nil + err = lastbody.Close() if err != nil { cc.lk.Lock() defer cc.lk.Unlock() @@ -333,24 +349,22 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { } } - resp, err = ReadResponse(cc.r, req.Method) + resp, err = ReadResponse(r, req.Method) + cc.lk.Lock() + defer cc.lk.Unlock() if err != nil { - cc.lk.Lock() - defer cc.lk.Unlock() cc.re = err - return + return resp, err } cc.lastbody = resp.Body cc.nread++ if resp.Close { - cc.lk.Lock() - defer cc.lk.Unlock() cc.re = ErrPersistEOF // don't send any more requests return resp, cc.re } - return + return resp, err } // Do is convenience method that writes a request and reads a response. diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go new file mode 100644 index 000000000..0f2ca458f --- /dev/null +++ b/src/pkg/http/proxy_test.go @@ -0,0 +1,45 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "os" + "testing" +) + +// TODO(mattn): +// test ProxyAuth + +var MatchNoProxyTests = []struct { + host string + match bool +}{ + {"localhost", true}, // match completely + {"barbaz.net", true}, // match as .barbaz.net + {"foobar.com:443", true}, // have a port but match + {"foofoobar.com", false}, // not match as a part of foobar.com + {"baz.com", false}, // not match as a part of barbaz.com + {"localhost.net", false}, // not match as suffix of address + {"local.localhost", false}, // not match as prefix as address + {"barbarbaz.net", false}, // not match because NO_PROXY have a '.' + {"www.foobar.com", false}, // not match because NO_PROXY is not .foobar.com +} + +func TestMatchNoProxy(t *testing.T) { + oldenv := os.Getenv("NO_PROXY") + no_proxy := "foobar.com, .barbaz.net , localhost" + os.Setenv("NO_PROXY", no_proxy) + defer os.Setenv("NO_PROXY", oldenv) + + for _, test := range MatchNoProxyTests { + if matchNoProxy(test.host) != test.match { + if test.match { + t.Errorf("matchNoProxy(%v) = %v, want %v", test.host, !test.match, test.match) + } else { + t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy) + } + } + } +} diff --git a/src/pkg/http/range_test.go b/src/pkg/http/range_test.go new file mode 100644 index 000000000..5274a81fa --- /dev/null +++ b/src/pkg/http/range_test.go @@ -0,0 +1,57 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "testing" +) + +var ParseRangeTests = []struct { + s string + length int64 + r []httpRange +}{ + {"", 0, nil}, + {"foo", 0, nil}, + {"bytes=", 0, nil}, + {"bytes=5-4", 10, nil}, + {"bytes=0-2,5-4", 10, nil}, + {"bytes=0-9", 10, []httpRange{{0, 10}}}, + {"bytes=0-", 10, []httpRange{{0, 10}}}, + {"bytes=5-", 10, []httpRange{{5, 5}}}, + {"bytes=0-20", 10, []httpRange{{0, 10}}}, + {"bytes=15-,0-5", 10, nil}, + {"bytes=-5", 10, []httpRange{{5, 5}}}, + {"bytes=-15", 10, []httpRange{{0, 10}}}, + {"bytes=0-499", 10000, []httpRange{{0, 500}}}, + {"bytes=500-999", 10000, []httpRange{{500, 500}}}, + {"bytes=-500", 10000, []httpRange{{9500, 500}}}, + {"bytes=9500-", 10000, []httpRange{{9500, 500}}}, + {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, + {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, + {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, +} + +func TestParseRange(t *testing.T) { + for _, test := range ParseRangeTests { + r := test.r + ranges, err := parseRange(test.s, test.length) + if err != nil && r != nil { + t.Errorf("parseRange(%q) returned error %q", test.s, err) + } + if len(ranges) != len(r) { + t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r)) + continue + } + for i := range r { + if ranges[i].start != r[i].start { + t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start) + } + if ranges[i].length != r[i].length { + t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length) + } + } + } +} diff --git a/src/pkg/http/readrequest_test.go b/src/pkg/http/readrequest_test.go index 5e1cbcbcb..19e2ff774 100644 --- a/src/pkg/http/readrequest_test.go +++ b/src/pkg/http/readrequest_test.go @@ -50,14 +50,14 @@ var reqTests = []reqTest{ Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, - Header: map[string]string{ - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", - "Accept-Language": "en-us,en;q=0.5", - "Accept-Encoding": "gzip,deflate", - "Accept-Charset": "ISO-8859-1,utf-8;q=0.7,*;q=0.7", - "Keep-Alive": "300", - "Proxy-Connection": "keep-alive", - "Content-Length": "7", + Header: Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "Content-Length": {"7"}, }, Close: false, ContentLength: 7, @@ -93,7 +93,7 @@ var reqTests = []reqTest{ Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, - Header: map[string]string{}, + Header: Header{}, Close: false, ContentLength: -1, Host: "test", diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index 04bebaaf5..d8456bab3 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -11,13 +11,13 @@ package http import ( "bufio" - "bytes" "container/vector" "fmt" "io" "io/ioutil" "mime" "mime/multipart" + "net/textproto" "os" "strconv" "strings" @@ -90,7 +90,10 @@ type Request struct { // The request parser implements this by canonicalizing the // name, making the first character and any characters // following a hyphen uppercase and the rest lowercase. - Header map[string]string + Header Header + + // Cookie records the HTTP cookies sent with the request. + Cookie []*Cookie // The message body. Body io.ReadCloser @@ -133,7 +136,7 @@ type Request struct { // Trailer maps trailer keys to values. Like for Header, if the // response has multiple trailer lines with the same key, they will be // concatenated, delimited by commas. - Trailer map[string]string + Trailer Header } // ProtoAtLeast returns whether the HTTP protocol used @@ -146,8 +149,8 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { // MultipartReader returns a MIME multipart reader if this is a // multipart/form-data POST request, else returns nil and an error. func (r *Request) MultipartReader() (multipart.Reader, os.Error) { - v, ok := r.Header["Content-Type"] - if !ok { + v := r.Header.Get("Content-Type") + if v == "" { return nil, ErrNotMultipart } d, params := mime.ParseMediaType(v) @@ -184,6 +187,19 @@ const defaultUserAgent = "Go http package" // If Body is present, Write forces "Transfer-Encoding: chunked" as a header // and then closes Body when finished sending it. func (req *Request) Write(w io.Writer) os.Error { + return req.write(w, false) +} + +// WriteProxy is like Write but writes the request in the form +// expected by an HTTP proxy. It includes the scheme and host +// name in the URI instead of using a separate Host: header line. +// If req.RawURL is non-empty, WriteProxy uses it unchanged +// instead of URL but still omits the Host: header. +func (req *Request) WriteProxy(w io.Writer) os.Error { + return req.write(w, true) +} + +func (req *Request) write(w io.Writer, usingProxy bool) os.Error { host := req.Host if host == "" { host = req.URL.Host @@ -195,12 +211,20 @@ func (req *Request) Write(w io.Writer) os.Error { if req.URL.RawQuery != "" { uri += "?" + req.URL.RawQuery } + if usingProxy { + if uri == "" || uri[0] != '/' { + uri = "/" + uri + } + uri = req.URL.Scheme + "://" + host + uri + } } fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), uri) // Header lines - fmt.Fprintf(w, "Host: %s\r\n", host) + if !usingProxy { + fmt.Fprintf(w, "Host: %s\r\n", host) + } fmt.Fprintf(w, "User-Agent: %s\r\n", valueOrDefault(req.UserAgent, defaultUserAgent)) if req.Referer != "" { fmt.Fprintf(w, "Referer: %s\r\n", req.Referer) @@ -223,11 +247,15 @@ func (req *Request) Write(w io.Writer) os.Error { // from Request, and introduce Request methods along the lines of // Response.{GetHeader,AddHeader} and string constants for "Host", // "User-Agent" and "Referer". - err = writeSortedKeyValue(w, req.Header, reqExcludeHeader) + err = writeSortedHeader(w, req.Header, reqExcludeHeader) if err != nil { return err } + if err = writeCookies(w, req.Cookie); err != nil { + return err + } + io.WriteString(w, "\r\n") // Write body and trailer @@ -277,78 +305,6 @@ func readLine(b *bufio.Reader) (s string, err os.Error) { return string(p), nil } -var colon = []byte{':'} - -// Read a key/value pair from b. -// A key/value has the form Key: Value\r\n -// and the Value can continue on multiple lines if each continuation line -// starts with a space. -func readKeyValue(b *bufio.Reader) (key, value string, err os.Error) { - line, e := readLineBytes(b) - if e != nil { - return "", "", e - } - if len(line) == 0 { - return "", "", nil - } - - // Scan first line for colon. - i := bytes.Index(line, colon) - if i < 0 { - goto Malformed - } - - key = string(line[0:i]) - if strings.Contains(key, " ") { - // Key field has space - no good. - goto Malformed - } - - // Skip initial space before value. - for i++; i < len(line); i++ { - if line[i] != ' ' { - break - } - } - value = string(line[i:]) - - // Look for extension lines, which must begin with space. - for { - c, e := b.ReadByte() - if c != ' ' { - if e != os.EOF { - b.UnreadByte() - } - break - } - - // Eat leading space. - for c == ' ' { - if c, e = b.ReadByte(); e != nil { - if e == os.EOF { - e = io.ErrUnexpectedEOF - } - return "", "", e - } - } - b.UnreadByte() - - // Read the rest of the line and add to value. - if line, e = readLineBytes(b); e != nil { - return "", "", e - } - value += " " + string(line) - - if len(value) >= maxValueLength { - return "", "", &badStringError{"value too long for key", key} - } - } - return key, value, nil - -Malformed: - return "", "", &badStringError{"malformed header line", string(line)} -} - // Convert decimal at s[i:len(s)] to integer, // returning value, string position where the digits stopped, // and whether there was a valid number (digits, not too big). @@ -367,8 +323,9 @@ func atoi(s string, i int) (n, i1 int, ok bool) { return n, i, true } -// Parse HTTP version: "HTTP/1.2" -> (1, 2, true). -func parseHTTPVersion(vers string) (int, int, bool) { +// ParseHTTPVersion parses a HTTP version string. +// "HTTP/1.0" returns (1, 0, true). +func ParseHTTPVersion(vers string) (major, minor int, ok bool) { if len(vers) < 5 || vers[0:5] != "HTTP/" { return 0, 0, false } @@ -376,7 +333,6 @@ func parseHTTPVersion(vers string) (int, int, bool) { if !ok || i >= len(vers) || vers[i] != '.' { return 0, 0, false } - var minor int minor, i, ok = atoi(vers, i+1) if !ok || i != len(vers) { return 0, 0, false @@ -384,43 +340,6 @@ func parseHTTPVersion(vers string) (int, int, bool) { return major, minor, true } -// CanonicalHeaderKey returns the canonical format of the -// HTTP header key s. The canonicalization converts the first -// letter and any letter following a hyphen to upper case; -// the rest are converted to lowercase. For example, the -// canonical key for "accept-encoding" is "Accept-Encoding". -func CanonicalHeaderKey(s string) string { - // canonicalize: first letter upper case - // and upper case after each dash. - // (Host, User-Agent, If-Modified-Since). - // HTTP headers are ASCII only, so no Unicode issues. - var a []byte - upper := true - for i := 0; i < len(s); i++ { - v := s[i] - if upper && 'a' <= v && v <= 'z' { - if a == nil { - a = []byte(s) - } - a[i] = v + 'A' - 'a' - } - if !upper && 'A' <= v && v <= 'Z' { - if a == nil { - a = []byte(s) - } - a[i] = v + 'a' - 'A' - } - upper = false - if v == '-' { - upper = true - } - } - if a != nil { - return string(a) - } - return s -} - type chunkedReader struct { r *bufio.Reader n uint64 // unread bytes in chunk @@ -486,11 +405,16 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) { // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { + + tp := textproto.NewReader(b) req = new(Request) // First line: GET /index.html HTTP/1.0 var s string - if s, err = readLine(b); err != nil { + if s, err = tp.ReadLine(); err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } return nil, err } @@ -500,7 +424,7 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { } req.Method, req.RawURL, req.Proto = f[0], f[1], f[2] var ok bool - if req.ProtoMajor, req.ProtoMinor, ok = parseHTTPVersion(req.Proto); !ok { + if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { return nil, &badStringError{"malformed HTTP version", req.Proto} } @@ -509,32 +433,11 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { } // Subsequent lines: Key: value. - nheader := 0 - req.Header = make(map[string]string) - for { - var key, value string - if key, value, err = readKeyValue(b); err != nil { - return nil, err - } - if key == "" { - break - } - if nheader++; nheader >= maxHeaderLines { - return nil, ErrHeaderTooLong - } - - key = CanonicalHeaderKey(key) - - // RFC 2616 says that if you send the same header key - // multiple times, it has to be semantically equivalent - // to concatenating the values separated by commas. - oldvalue, present := req.Header[key] - if present { - req.Header[key] = oldvalue + "," + value - } else { - req.Header[key] = value - } + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err } + req.Header = Header(mimeHeader) // RFC2616: Must treat // GET /index.html HTTP/1.1 @@ -545,18 +448,18 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { // the same. In the second case, any Host line is ignored. req.Host = req.URL.Host if req.Host == "" { - req.Host = req.Header["Host"] + req.Host = req.Header.Get("Host") } - req.Header["Host"] = "", false + req.Header.Del("Host") fixPragmaCacheControl(req.Header) // Pull out useful fields as a convenience to clients. - req.Referer = req.Header["Referer"] - req.Header["Referer"] = "", false + req.Referer = req.Header.Get("Referer") + req.Header.Del("Referer") - req.UserAgent = req.Header["User-Agent"] - req.Header["User-Agent"] = "", false + req.UserAgent = req.Header.Get("User-Agent") + req.Header.Del("User-Agent") // TODO: Parse specific header values: // Accept @@ -589,6 +492,8 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { return nil, err } + req.Cookie = readCookies(req.Header) + return req, nil } @@ -642,7 +547,7 @@ func (r *Request) ParseForm() (err os.Error) { if r.Body == nil { return os.ErrorString("missing form body") } - ct := r.Header["Content-Type"] + ct := r.Header.Get("Content-Type") switch strings.Split(ct, ";", 2)[0] { case "text/plain", "application/x-www-form-urlencoded", "": b, e := ioutil.ReadAll(r.Body) @@ -677,17 +582,12 @@ func (r *Request) FormValue(key string) string { } func (r *Request) expectsContinue() bool { - expectation, ok := r.Header["Expect"] - return ok && strings.ToLower(expectation) == "100-continue" + return strings.ToLower(r.Header.Get("Expect")) == "100-continue" } func (r *Request) wantsHttp10KeepAlive() bool { if r.ProtoMajor != 1 || r.ProtoMinor != 0 { return false } - value, exists := r.Header["Connection"] - if !exists { - return false - } - return strings.Contains(strings.ToLower(value), "keep-alive") + return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive") } diff --git a/src/pkg/http/request_test.go b/src/pkg/http/request_test.go index d25e5e5e7..ae1c4e982 100644 --- a/src/pkg/http/request_test.go +++ b/src/pkg/http/request_test.go @@ -74,7 +74,9 @@ func TestQuery(t *testing.T) { func TestPostQuery(t *testing.T) { req := &Request{Method: "POST"} req.URL, _ = ParseURL("http://www.google.com/search?q=foo&q=bar&both=x") - req.Header = map[string]string{"Content-Type": "application/x-www-form-urlencoded; boo!"} + req.Header = Header{ + "Content-Type": {"application/x-www-form-urlencoded; boo!"}, + } req.Body = nopCloser{strings.NewReader("z=post&both=y")} if q := req.FormValue("q"); q != "foo" { t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) @@ -87,18 +89,18 @@ func TestPostQuery(t *testing.T) { } } -type stringMap map[string]string +type stringMap map[string][]string type parseContentTypeTest struct { contentType stringMap error bool } var parseContentTypeTests = []parseContentTypeTest{ - {contentType: stringMap{"Content-Type": "text/plain"}}, - {contentType: stringMap{"Content-Type": ""}}, - {contentType: stringMap{"Content-Type": "text/plain; boundary="}}, + {contentType: stringMap{"Content-Type": {"text/plain"}}}, + {contentType: stringMap{}}, // Non-existent keys are not placed. The value nil is illegal. + {contentType: stringMap{"Content-Type": {"text/plain; boundary="}}}, { - contentType: stringMap{"Content-Type": "application/unknown"}, + contentType: stringMap{"Content-Type": {"application/unknown"}}, error: true, }, } @@ -107,7 +109,7 @@ func TestPostContentTypeParsing(t *testing.T) { for i, test := range parseContentTypeTests { req := &Request{ Method: "POST", - Header: test.contentType, + Header: Header(test.contentType), Body: nopCloser{bytes.NewBufferString("body")}, } err := req.ParseForm() @@ -123,7 +125,7 @@ func TestPostContentTypeParsing(t *testing.T) { func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", - Header: stringMap{"Content-Type": `multipart/form-data; boundary="foo123"`}, + Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, Body: nopCloser{new(bytes.Buffer)}, } multipart, err := req.MultipartReader() @@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) { t.Errorf("expected multipart; error: %v", err) } - req.Header = stringMap{"Content-Type": "text/plain"} + req.Header = Header{"Content-Type": {"text/plain"}} multipart, err = req.MultipartReader() if multipart != nil { t.Errorf("unexpected multipart for text/plain") diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index 3ceabe4ee..03a766efd 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -10,8 +10,10 @@ import ( ) type reqWriteTest struct { - Req Request - Raw string + Req Request + Body []byte + Raw string + RawProxy string } var reqWriteTests = []reqWriteTest{ @@ -34,13 +36,13 @@ var reqWriteTests = []reqWriteTest{ Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, - Header: map[string]string{ - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", - "Accept-Charset": "ISO-8859-1,utf-8;q=0.7,*;q=0.7", - "Accept-Encoding": "gzip,deflate", - "Accept-Language": "en-us,en;q=0.5", - "Keep-Alive": "300", - "Proxy-Connection": "keep-alive", + Header: Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, }, Body: nil, Close: false, @@ -50,13 +52,24 @@ var reqWriteTests = []reqWriteTest{ Form: map[string][]string{}, }, + nil, + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + "Host: www.techcrunch.com\r\n" + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + "Accept-Encoding: gzip,deflate\r\n" + "Accept-Language: en-us,en;q=0.5\r\n" + + "Keep-Alive: 300\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n", + + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "User-Agent: Fake\r\n" + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + "Keep-Alive: 300\r\n" + "Proxy-Connection: keep-alive\r\n\r\n", }, @@ -71,16 +84,22 @@ var reqWriteTests = []reqWriteTest{ }, ProtoMajor: 1, ProtoMinor: 1, - Header: map[string]string{}, - Body: nopCloser{bytes.NewBufferString("abcdef")}, + Header: Header{}, TransferEncoding: []string{"chunked"}, }, + []byte("abcdef"), + "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", + + "GET http://www.google.com/search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", }, // HTTP/1.1 POST => chunked coding; body; empty trailer { @@ -93,18 +112,25 @@ var reqWriteTests = []reqWriteTest{ }, ProtoMajor: 1, ProtoMinor: 1, - Header: map[string]string{}, + Header: Header{}, Close: true, - Body: nopCloser{bytes.NewBufferString("abcdef")}, TransferEncoding: []string{"chunked"}, }, + []byte("abcdef"), + "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", + + "POST http://www.google.com/search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", }, // default to HTTP/1.1 { @@ -114,16 +140,26 @@ var reqWriteTests = []reqWriteTest{ Host: "www.google.com", }, + nil, + "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "\r\n", + + // Looks weird but RawURL overrides what WriteProxy would choose. + "GET /search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "\r\n", }, } func TestRequestWrite(t *testing.T) { for i := range reqWriteTests { tt := &reqWriteTests[i] + if tt.Body != nil { + tt.Req.Body = nopCloser{bytes.NewBuffer(tt.Body)} + } var braw bytes.Buffer err := tt.Req.Write(&braw) if err != nil { @@ -135,5 +171,20 @@ func TestRequestWrite(t *testing.T) { t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, sraw) continue } + + if tt.Body != nil { + tt.Req.Body = nopCloser{bytes.NewBuffer(tt.Body)} + } + var praw bytes.Buffer + err = tt.Req.WriteProxy(&praw) + if err != nil { + t.Errorf("error writing #%d: %s", i, err) + continue + } + sraw = praw.String() + if sraw != tt.RawProxy { + t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.RawProxy, sraw) + continue + } } } diff --git a/src/pkg/http/response.go b/src/pkg/http/response.go index a24726110..3d77c5555 100644 --- a/src/pkg/http/response.go +++ b/src/pkg/http/response.go @@ -10,6 +10,7 @@ import ( "bufio" "fmt" "io" + "net/textproto" "os" "sort" "strconv" @@ -43,7 +44,10 @@ type Response struct { // omitted from Header. // // Keys in the map are canonicalized (see CanonicalHeaderKey). - Header map[string]string + Header Header + + // SetCookie records the Set-Cookie requests sent with the response. + SetCookie []*Cookie // Body represents the response body. Body io.ReadCloser @@ -63,10 +67,9 @@ type Response struct { // ReadResponse nor Response.Write ever closes a connection. Close bool - // Trailer maps trailer keys to values. Like for Header, if the - // response has multiple trailer lines with the same key, they will be - // concatenated, delimited by commas. - Trailer map[string]string + // Trailer maps trailer keys to values, in the same + // format as the header. + Trailer Header } // ReadResponse reads and returns an HTTP response from r. The RequestMethod @@ -76,13 +79,17 @@ type Response struct { // key/value pairs included in the response trailer. func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { + tp := textproto.NewReader(r) resp = new(Response) resp.RequestMethod = strings.ToUpper(requestMethod) // Parse the first line of the response. - line, err := readLine(r) + line, err := tp.ReadLine() if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } return nil, err } f := strings.Split(line, " ", 3) @@ -101,26 +108,16 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os resp.Proto = f[0] var ok bool - if resp.ProtoMajor, resp.ProtoMinor, ok = parseHTTPVersion(resp.Proto); !ok { + if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { return nil, &badStringError{"malformed HTTP version", resp.Proto} } // Parse the response headers. - nheader := 0 - resp.Header = make(map[string]string) - for { - key, value, err := readKeyValue(r) - if err != nil { - return nil, err - } - if key == "" { - break // end of response header - } - if nheader++; nheader >= maxHeaderLines { - return nil, ErrHeaderTooLong - } - resp.AddHeader(key, value) + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err } + resp.Header = Header(mimeHeader) fixPragmaCacheControl(resp.Header) @@ -129,6 +126,8 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os return nil, err } + resp.SetCookie = readSetCookies(resp.Header) + return resp, nil } @@ -136,34 +135,14 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os // Pragma: no-cache // like // Cache-Control: no-cache -func fixPragmaCacheControl(header map[string]string) { - if header["Pragma"] == "no-cache" { +func fixPragmaCacheControl(header Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { if _, presentcc := header["Cache-Control"]; !presentcc { - header["Cache-Control"] = "no-cache" + header["Cache-Control"] = []string{"no-cache"} } } } -// AddHeader adds a value under the given key. Keys are not case sensitive. -func (r *Response) AddHeader(key, value string) { - key = CanonicalHeaderKey(key) - - oldValues, oldValuesPresent := r.Header[key] - if oldValuesPresent { - r.Header[key] = oldValues + "," + value - } else { - r.Header[key] = value - } -} - -// GetHeader returns the value of the response header with the given key. -// If there were multiple headers with this key, their values are concatenated, -// with a comma delimiter. If there were no response headers with the given -// key, GetHeader returns an empty string. Keys are not case sensitive. -func (r *Response) GetHeader(key string) (value string) { - return r.Header[CanonicalHeaderKey(key)] -} - // ProtoAtLeast returns whether the HTTP protocol used // in the response is at least major.minor. func (r *Response) ProtoAtLeast(major, minor int) bool { @@ -213,11 +192,15 @@ func (resp *Response) Write(w io.Writer) os.Error { } // Rest of header - err = writeSortedKeyValue(w, resp.Header, respExcludeHeader) + err = writeSortedHeader(w, resp.Header, respExcludeHeader) if err != nil { return err } + if err = writeSetCookies(w, resp.SetCookie); err != nil { + return err + } + // End-of-header io.WriteString(w, "\r\n") @@ -231,20 +214,19 @@ func (resp *Response) Write(w io.Writer) os.Error { return nil } -func writeSortedKeyValue(w io.Writer, kvm map[string]string, exclude map[string]bool) os.Error { - kva := make([]string, len(kvm)) - i := 0 - for k, v := range kvm { +func writeSortedHeader(w io.Writer, h Header, exclude map[string]bool) os.Error { + keys := make([]string, 0, len(h)) + for k := range h { if !exclude[k] { - kva[i] = fmt.Sprint(k + ": " + v + "\r\n") - i++ + keys = append(keys, k) } } - kva = kva[0:i] - sort.SortStrings(kva) - for _, l := range kva { - if _, err := io.WriteString(w, l); err != nil { - return err + sort.SortStrings(keys) + for _, k := range keys { + for _, v := range h[k] { + if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { + return err + } } } return nil diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go index 11bfdd08c..bf63ccb9e 100644 --- a/src/pkg/http/response_test.go +++ b/src/pkg/http/response_test.go @@ -34,8 +34,8 @@ var respTests = []respTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{ - "Connection": "close", // TODO(rsc): Delete? + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? }, Close: true, ContentLength: -1, @@ -100,9 +100,9 @@ var respTests = []respTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{ - "Connection": "close", // TODO(rsc): Delete? - "Content-Length": "10", // TODO(rsc): Delete? + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? + "Content-Length": {"10"}, // TODO(rsc): Delete? }, Close: true, ContentLength: 10, @@ -128,7 +128,7 @@ var respTests = []respTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Close: true, ContentLength: -1, TransferEncoding: []string{"chunked"}, @@ -155,7 +155,7 @@ var respTests = []respTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Close: true, ContentLength: -1, // TODO(rsc): Fix? TransferEncoding: []string{"chunked"}, @@ -175,7 +175,7 @@ var respTests = []respTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Close: true, ContentLength: -1, }, @@ -194,7 +194,7 @@ var respTests = []respTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Close: true, ContentLength: -1, }, diff --git a/src/pkg/http/responsewrite_test.go b/src/pkg/http/responsewrite_test.go index 9f10be562..228ed5f7d 100644 --- a/src/pkg/http/responsewrite_test.go +++ b/src/pkg/http/responsewrite_test.go @@ -22,7 +22,7 @@ var respWriteTests = []respWriteTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Body: nopCloser{bytes.NewBufferString("abcdef")}, ContentLength: 6, }, @@ -38,7 +38,7 @@ var respWriteTests = []respWriteTest{ ProtoMajor: 1, ProtoMinor: 0, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Body: nopCloser{bytes.NewBufferString("abcdef")}, ContentLength: -1, }, @@ -53,7 +53,7 @@ var respWriteTests = []respWriteTest{ ProtoMajor: 1, ProtoMinor: 1, RequestMethod: "GET", - Header: map[string]string{}, + Header: Header{}, Body: nopCloser{bytes.NewBufferString("abcdef")}, ContentLength: 6, TransferEncoding: []string{"chunked"}, diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index 5594d512a..86d64bdbb 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -4,16 +4,18 @@ // End-to-end serving tests -package http +package http_test import ( "bufio" "bytes" "fmt" - "io" + . "http" + "http/httptest" "io/ioutil" "os" "net" + "strings" "testing" "time" ) @@ -169,13 +171,10 @@ func TestHostHandlers(t *testing.T) { for _, h := range handlers { Handle(h.pattern, stringHandler(h.msg)) } - l, err := net.Listen("tcp", "127.0.0.1:0") // any port - if err != nil { - t.Fatal(err) - } - defer l.Close() - go Serve(l, nil) - conn, err := net.Dial("tcp", "", l.Addr().String()) + ts := httptest.NewServer(nil) + defer ts.Close() + + conn, err := net.Dial("tcp", "", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } @@ -197,53 +196,13 @@ func TestHostHandlers(t *testing.T) { t.Errorf("reading response: %v", err) continue } - s := r.Header["Result"] + s := r.Header.Get("Result") if s != vt.expected { t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) } } } -type responseWriterMethodCall struct { - method string - headerKey, headerValue string // if method == "SetHeader" - bytesWritten []byte // if method == "Write" - responseCode int // if method == "WriteHeader" -} - -type recordingResponseWriter struct { - log []*responseWriterMethodCall -} - -func (rw *recordingResponseWriter) RemoteAddr() string { - return "1.2.3.4" -} - -func (rw *recordingResponseWriter) UsingTLS() bool { - return false -} - -func (rw *recordingResponseWriter) SetHeader(k, v string) { - rw.log = append(rw.log, &responseWriterMethodCall{method: "SetHeader", headerKey: k, headerValue: v}) -} - -func (rw *recordingResponseWriter) Write(buf []byte) (int, os.Error) { - rw.log = append(rw.log, &responseWriterMethodCall{method: "Write", bytesWritten: buf}) - return len(buf), nil -} - -func (rw *recordingResponseWriter) WriteHeader(code int) { - rw.log = append(rw.log, &responseWriterMethodCall{method: "WriteHeader", responseCode: code}) -} - -func (rw *recordingResponseWriter) Flush() { - rw.log = append(rw.log, &responseWriterMethodCall{method: "Flush"}) -} - -func (rw *recordingResponseWriter) Hijack() (io.ReadWriteCloser, *bufio.ReadWriter, os.Error) { - panic("Not supported") -} - // Tests for http://code.google.com/p/go/issues/detail?id=900 func TestMuxRedirectLeadingSlashes(t *testing.T) { paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} @@ -253,35 +212,17 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { t.Errorf("%s", err) } mux := NewServeMux() - resp := new(recordingResponseWriter) - resp.log = make([]*responseWriterMethodCall, 0) + resp := httptest.NewRecorder() mux.ServeHTTP(resp, req) - dumpLog := func() { - t.Logf("For path %q:", path) - for _, call := range resp.log { - t.Logf("Got call: %s, header=%s, value=%s, buf=%q, code=%d", call.method, - call.headerKey, call.headerValue, call.bytesWritten, call.responseCode) - } - } - - if len(resp.log) != 2 { - dumpLog() - t.Errorf("expected 2 calls to response writer; got %d", len(resp.log)) + if loc, expected := resp.Header.Get("Location"), "/foo.txt"; loc != expected { + t.Errorf("Expected Location header set to %q; got %q", expected, loc) return } - if resp.log[0].method != "SetHeader" || - resp.log[0].headerKey != "Location" || resp.log[0].headerValue != "/foo.txt" { - dumpLog() - t.Errorf("Expected SetHeader of Location to /foo.txt") - return - } - - if resp.log[1].method != "WriteHeader" || resp.log[1].responseCode != StatusMovedPermanently { - dumpLog() - t.Errorf("Expected WriteHeader of StatusMovedPermanently") + if code, expected := resp.Code, StatusMovedPermanently; code != expected { + t.Errorf("Expected response code of StatusMovedPermanently; got %d", code) return } } @@ -349,3 +290,78 @@ func TestServerTimeouts(t *testing.T) { l.Close() } + +// TestIdentityResponse verifies that a handler can unset +func TestIdentityResponse(t *testing.T) { + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.SetHeader("Content-Length", "3") + rw.SetHeader("Transfer-Encoding", req.FormValue("te")) + switch { + case req.FormValue("overwrite") == "1": + _, err := rw.Write([]byte("foo TOO LONG")) + if err != ErrContentLength { + t.Errorf("expected ErrContentLength; got %v", err) + } + case req.FormValue("underwrite") == "1": + rw.SetHeader("Content-Length", "500") + rw.Write([]byte("too short")) + default: + rw.Write([]byte("foo")) + } + }) + + ts := httptest.NewServer(handler) + defer ts.Close() + + // Note: this relies on the assumption (which is true) that + // Get sends HTTP/1.1 or greater requests. Otherwise the + // server wouldn't have the choice to send back chunked + // responses. + for _, te := range []string{"", "identity"} { + url := ts.URL + "/?te=" + te + res, _, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + if cl, expected := res.ContentLength, int64(3); cl != expected { + t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl) + } + if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected { + t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl) + } + if tl, expected := len(res.TransferEncoding), 0; tl != expected { + t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)", + url, expected, tl, res.TransferEncoding) + } + } + + // Verify that ErrContentLength is returned + url := ts.URL + "/?overwrite=1" + _, _, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + + // Verify that the connection is closed when the declared Content-Length + // is larger than what the handler wrote. + conn, err := net.Dial("tcp", "", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("error writing: %v", err) + } + // The next ReadAll will hang for a failing test, so use a Timer instead + // to fail more traditionally + timer := time.AfterFunc(2e9, func() { + t.Fatalf("Timeout expired in ReadAll.") + }) + defer timer.Stop() + got, _ := ioutil.ReadAll(conn) + expectedSuffix := "\r\n\r\ntoo short" + if !strings.HasSuffix(string(got), expectedSuffix) { + t.Fatalf("Expected output to end with %q; got response body %q", + expectedSuffix, string(got)) + } +} diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index 0be270ad3..5d623e696 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -31,6 +31,7 @@ var ( ErrWriteAfterFlush = os.NewError("Conn.Write called after Flush") ErrBodyNotAllowed = os.NewError("http: response status code does not allow body") ErrHijacked = os.NewError("Conn has been hijacked") + ErrContentLength = os.NewError("Conn.Write wrote more than the declared Content-Length") ) // Objects implementing the Handler interface can be @@ -60,10 +61,10 @@ type ResponseWriter interface { // // Content-Type: text/html; charset=utf-8 // - // being sent. UTF-8 encoded HTML is the default setting for + // being sent. UTF-8 encoded HTML is the default setting for // Content-Type in this library, so users need not make that - // particular call. Calls to SetHeader after WriteHeader (or Write) - // are ignored. + // particular call. Calls to SetHeader after WriteHeader (or Write) + // are ignored. An empty value removes the header if previously set. SetHeader(string, string) // Write writes the data to the connection as part of an HTTP reply. @@ -80,7 +81,10 @@ type ResponseWriter interface { // Flush sends any buffered data to the client. Flush() +} +// A Hijacker is an HTTP request which be taken over by an HTTP handler. +type Hijacker interface { // Hijack lets the caller take over the connection. // After a call to Hijack(), the HTTP server library // will not do anything else with the connection. @@ -108,6 +112,7 @@ type response struct { wroteContinue bool // 100 Continue response was written header map[string]string // reply header parameters written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 status int // status code passed to WriteHeader // close connection after this reply. set on request and @@ -170,33 +175,13 @@ func (c *conn) readRequest() (w *response, err os.Error) { w.conn = c w.req = req w.header = make(map[string]string) + w.contentLength = -1 // Expect 100 Continue support if req.expectsContinue() && req.ProtoAtLeast(1, 1) { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} } - - // Default output is HTML encoded in UTF-8. - w.SetHeader("Content-Type", "text/html; charset=utf-8") - w.SetHeader("Date", time.UTC().Format(TimeFormat)) - - if req.Method == "HEAD" { - // do nothing - } else if req.ProtoAtLeast(1, 1) { - // HTTP/1.1 or greater: use chunked transfer encoding - // to avoid closing the connection at EOF. - w.chunking = true - w.SetHeader("Transfer-Encoding", "chunked") - } else { - // HTTP version < 1.1: cannot do chunked transfer - // encoding, so signal EOF by closing connection. - // Will be overridden if the HTTP handler ends up - // writing a Content-Length and the client requested - // "Connection: keep-alive" - w.closeAfterReply = true - } - return w, nil } @@ -209,7 +194,10 @@ func (w *response) UsingTLS() bool { func (w *response) RemoteAddr() string { return w.conn.remoteAddr } // SetHeader implements the ResponseWriter.SetHeader method -func (w *response) SetHeader(hdr, val string) { w.header[CanonicalHeaderKey(hdr)] = val } +// An empty value removes the header from the map. +func (w *response) SetHeader(hdr, val string) { + w.header[CanonicalHeaderKey(hdr)] = val, val != "" +} // WriteHeader implements the ResponseWriter.WriteHeader method func (w *response) WriteHeader(code int) { @@ -225,13 +213,86 @@ func (w *response) WriteHeader(code int) { w.status = code if code == StatusNotModified { // Must not have body. - w.header["Content-Type"] = "", false - w.header["Transfer-Encoding"] = "", false + for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { + if w.header[header] != "" { + // TODO: return an error if WriteHeader gets a return parameter + // or set a flag on w to make future Writes() write an error page? + // for now just log and drop the header. + log.Printf("http: StatusNotModified response with header %q defined", header) + w.header[header] = "", false + } + } + } else { + // Default output is HTML encoded in UTF-8. + if w.header["Content-Type"] == "" { + w.SetHeader("Content-Type", "text/html; charset=utf-8") + } + } + + if w.header["Date"] == "" { + w.SetHeader("Date", time.UTC().Format(TimeFormat)) + } + + // Check for a explicit (and valid) Content-Length header. + var hasCL bool + var contentLength int64 + if clenStr, ok := w.header["Content-Length"]; ok { + var err os.Error + contentLength, err = strconv.Atoi64(clenStr) + if err == nil { + hasCL = true + } else { + log.Printf("http: invalid Content-Length of %q sent", clenStr) + w.SetHeader("Content-Length", "") + } + } + + te, hasTE := w.header["Transfer-Encoding"] + if hasCL && hasTE && te != "identity" { + // TODO: return an error if WriteHeader gets a return parameter + // For now just ignore the Content-Length. + log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + te, contentLength) + w.SetHeader("Content-Length", "") + hasCL = false + } + + if w.req.Method == "HEAD" { + // do nothing + } else if hasCL { w.chunking = false + w.contentLength = contentLength + w.SetHeader("Transfer-Encoding", "") + } else if w.req.ProtoAtLeast(1, 1) { + // HTTP/1.1 or greater: use chunked transfer encoding + // to avoid closing the connection at EOF. + // TODO: this blows away any custom or stacked Transfer-Encoding they + // might have set. Deal with that as need arises once we have a valid + // use case. + w.chunking = true + w.SetHeader("Transfer-Encoding", "chunked") + } else { + // HTTP version < 1.1: cannot do chunked transfer + // encoding and we don't know the Content-Length so + // signal EOF by closing connection. + w.closeAfterReply = true + w.chunking = false // redundant + w.SetHeader("Transfer-Encoding", "") // in case already set } + + if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { + _, connectionHeaderSet := w.header["Connection"] + if !connectionHeaderSet { + w.SetHeader("Connection", "keep-alive") + } + } else if !w.req.ProtoAtLeast(1, 1) { + // Client did not ask to keep connection alive. + w.closeAfterReply = true + } + // Cannot use Content-Length with non-identity Transfer-Encoding. if w.chunking { - w.header["Content-Length"] = "", false + w.SetHeader("Content-Length", "") } if !w.req.ProtoAtLeast(1, 0) { return @@ -259,15 +320,6 @@ func (w *response) Write(data []byte) (n int, err os.Error) { return 0, ErrHijacked } if !w.wroteHeader { - if w.req.wantsHttp10KeepAlive() { - _, hasLength := w.header["Content-Length"] - if hasLength { - _, connectionHeaderSet := w.header["Connection"] - if !connectionHeaderSet { - w.header["Connection"] = "keep-alive" - } - } - } w.WriteHeader(StatusOK) } if len(data) == 0 { @@ -280,6 +332,9 @@ func (w *response) Write(data []byte) (n int, err os.Error) { } w.written += int64(len(data)) // ignoring errors, for errorKludge + if w.contentLength != -1 && w.written > w.contentLength { + return 0, ErrContentLength + } // TODO(rsc): if chunking happened after the buffering, // then there would be fewer chunk headers. @@ -369,6 +424,11 @@ func (w *response) finishRequest() { } w.conn.buf.Flush() w.req.Body.Close() + + if w.contentLength != -1 && w.contentLength != w.written { + // Did not write enough. Avoid getting out of sync. + w.closeAfterReply = true + } } // Flush implements the ResponseWriter.Flush method. @@ -657,10 +717,12 @@ func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Re // Handle registers the handler for the given pattern // in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } // HandleFunc registers the handler function for the given pattern // in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { DefaultServeMux.HandleFunc(pattern, handler) } diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go index f80f0ac63..996e28973 100644 --- a/src/pkg/http/transfer.go +++ b/src/pkg/http/transfer.go @@ -21,7 +21,7 @@ type transferWriter struct { ContentLength int64 Close bool TransferEncoding []string - Trailer map[string]string + Trailer Header } func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { @@ -159,7 +159,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) { type transferReader struct { // Input - Header map[string]string + Header Header StatusCode int RequestMethod string ProtoMajor int @@ -169,7 +169,7 @@ type transferReader struct { ContentLength int64 TransferEncoding []string Close bool - Trailer map[string]string + Trailer Header } // bodyAllowedForStatus returns whether a given response status code @@ -289,14 +289,14 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } // Sanitize transfer encoding -func fixTransferEncoding(header map[string]string) ([]string, os.Error) { +func fixTransferEncoding(header Header) ([]string, os.Error) { raw, present := header["Transfer-Encoding"] if !present { return nil, nil } - header["Transfer-Encoding"] = "", false - encodings := strings.Split(raw, ",", -1) + header["Transfer-Encoding"] = nil, false + encodings := strings.Split(raw[0], ",", -1) te := make([]string, 0, len(encodings)) // TODO: Even though we only support "identity" and "chunked" // encodings, the loop below is designed with foresight. One @@ -321,7 +321,7 @@ func fixTransferEncoding(header map[string]string) ([]string, os.Error) { // Chunked encoding trumps Content-Length. See RFC 2616 // Section 4.4. Currently len(te) > 0 implies chunked // encoding. - header["Content-Length"] = "", false + header["Content-Length"] = nil, false return te, nil } @@ -331,7 +331,7 @@ func fixTransferEncoding(header map[string]string) ([]string, os.Error) { // Determine the expected body length, using RFC 2616 Section 4.4. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. -func fixLength(status int, requestMethod string, header map[string]string, te []string) (int64, os.Error) { +func fixLength(status int, requestMethod string, header Header, te []string) (int64, os.Error) { // Logic based on response type or status if noBodyExpected(requestMethod) { @@ -351,23 +351,21 @@ func fixLength(status int, requestMethod string, header map[string]string, te [] } // Logic based on Content-Length - if cl, present := header["Content-Length"]; present { - cl = strings.TrimSpace(cl) - if cl != "" { - n, err := strconv.Atoi64(cl) - if err != nil || n < 0 { - return -1, &badStringError{"bad Content-Length", cl} - } - return n, nil - } else { - header["Content-Length"] = "", false + cl := strings.TrimSpace(header.Get("Content-Length")) + if cl != "" { + n, err := strconv.Atoi64(cl) + if err != nil || n < 0 { + return -1, &badStringError{"bad Content-Length", cl} } + return n, nil + } else { + header.Del("Content-Length") } // Logic based on media type. The purpose of the following code is just // to detect whether the unsupported "multipart/byteranges" is being // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header["Content-Type"]), "multipart/byteranges") { + if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") { return -1, ErrNotSupported } @@ -378,24 +376,19 @@ func fixLength(status int, requestMethod string, header map[string]string, te [] // Determine whether to hang up after sending a request and body, or // receiving a response and body // 'header' is the request headers -func shouldClose(major, minor int, header map[string]string) bool { +func shouldClose(major, minor int, header Header) bool { if major < 1 { return true } else if major == 1 && minor == 0 { - v, present := header["Connection"] - if !present { - return true - } - v = strings.ToLower(v) - if !strings.Contains(v, "keep-alive") { + if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") { return true } return false - } else if v, present := header["Connection"]; present { + } else { // TODO: Should split on commas, toss surrounding white space, // and check each field. - if v == "close" { - header["Connection"] = "", false + if strings.ToLower(header.Get("Connection")) == "close" { + header.Del("Connection") return true } } @@ -403,14 +396,14 @@ func shouldClose(major, minor int, header map[string]string) bool { } // Parse the trailer header -func fixTrailer(header map[string]string, te []string) (map[string]string, os.Error) { - raw, present := header["Trailer"] - if !present { +func fixTrailer(header Header, te []string) (Header, os.Error) { + raw := header.Get("Trailer") + if raw == "" { return nil, nil } - header["Trailer"] = "", false - trailer := make(map[string]string) + header.Del("Trailer") + trailer := make(Header) keys := strings.Split(raw, ",", -1) for _, key := range keys { key = CanonicalHeaderKey(strings.TrimSpace(key)) @@ -418,7 +411,7 @@ func fixTrailer(header map[string]string, te []string) (map[string]string, os.Er case "Transfer-Encoding", "Trailer", "Content-Length": return nil, &badStringError{"bad trailer key", key} } - trailer[key] = "" + trailer.Del(key) } if len(trailer) == 0 { return nil, nil diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go new file mode 100644 index 000000000..78d316a55 --- /dev/null +++ b/src/pkg/http/transport.go @@ -0,0 +1,151 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "crypto/tls" + "encoding/base64" + "fmt" + "net" + "os" + "strings" + "sync" +) + +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes a new network connection for +// each call to Do and uses HTTP proxies as directed by the +// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) +// environment variables. +var DefaultTransport Transport = &transport{} + +// transport implements Tranport for the default case, using TCP +// connections to either the host or a proxy, serving http or https +// schemes. In the future this may become public and support options +// on keep-alive connection duration, pipelining controls, etc. For +// now this is simply a port of the old Go code client code to the +// Transport interface. +type transport struct { + // TODO: keep-alives, pipelining, etc using a map from + // scheme/host to a connection. Something like: + l sync.Mutex + hostConn map[string]*ClientConn +} + +func (ct *transport) Do(req *Request) (resp *Response, err os.Error) { + if req.URL.Scheme != "http" && req.URL.Scheme != "https" { + return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + } + + addr := req.URL.Host + if !hasPort(addr) { + addr += ":" + req.URL.Scheme + } + + var proxyURL *URL + proxyAuth := "" + proxy := "" + if !matchNoProxy(addr) { + proxy = os.Getenv("HTTP_PROXY") + if proxy == "" { + proxy = os.Getenv("http_proxy") + } + } + + var write = (*Request).Write + + if proxy != "" { + write = (*Request).WriteProxy + proxyURL, err = ParseRequestURL(proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + if proxyURL.Host == "" { + proxyURL, err = ParseRequestURL("http://" + proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + } + addr = proxyURL.Host + proxyInfo := proxyURL.RawUserinfo + if proxyInfo != "" { + enc := base64.URLEncoding + encoded := make([]byte, enc.EncodedLen(len(proxyInfo))) + enc.Encode(encoded, []byte(proxyInfo)) + proxyAuth = "Basic " + string(encoded) + } + } + + // Connect to server or proxy + conn, err := net.Dial("tcp", "", addr) + if err != nil { + return nil, err + } + + if req.URL.Scheme == "http" { + // Include proxy http header if needed. + if proxyAuth != "" { + req.Header.Set("Proxy-Authorization", proxyAuth) + } + } else { // https + if proxyURL != nil { + // Ask proxy for direct connection to server. + // addr defaults above to ":https" but we need to use numbers + addr = req.URL.Host + if !hasPort(addr) { + addr += ":443" + } + fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", addr) + fmt.Fprintf(conn, "Host: %s\r\n", addr) + if proxyAuth != "" { + fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", proxyAuth) + } + fmt.Fprintf(conn, "\r\n") + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := ReadResponse(br, "CONNECT") + if err != nil { + return nil, err + } + if resp.StatusCode != 200 { + f := strings.Split(resp.Status, " ", 2) + return nil, os.ErrorString(f[1]) + } + } + + // Initiate TLS and check remote host name against certificate. + conn = tls.Client(conn, nil) + if err = conn.(*tls.Conn).Handshake(); err != nil { + return nil, err + } + h := req.URL.Host + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + if err = conn.(*tls.Conn).VerifyHostname(h); err != nil { + return nil, err + } + } + + err = write(req, conn) + if err != nil { + conn.Close() + return nil, err + } + + reader := bufio.NewReader(conn) + resp, err = ReadResponse(reader, req.Method) + if err != nil { + conn.Close() + return nil, err + } + + resp.Body = readClose{resp.Body, conn} + return +} |