diff options
Diffstat (limited to 'src/pkg/http')
30 files changed, 982 insertions, 456 deletions
diff --git a/src/pkg/http/cgi/child.go b/src/pkg/http/cgi/child.go index e1ad7ad32..8b74d7054 100644 --- a/src/pkg/http/cgi/child.go +++ b/src/pkg/http/cgi/child.go @@ -45,13 +45,6 @@ func envMap(env []string) map[string]string { return m } -// These environment variables are manually copied into Request -var skipHeader = map[string]bool{ - "HTTP_HOST": true, - "HTTP_REFERER": true, - "HTTP_USER_AGENT": true, -} - // RequestFromMap creates an http.Request from CGI variables. // The returned Request's Body field is not populated. func RequestFromMap(params map[string]string) (*http.Request, os.Error) { @@ -73,8 +66,6 @@ func RequestFromMap(params map[string]string) (*http.Request, os.Error) { r.Header = http.Header{} r.Host = params["HTTP_HOST"] - r.Referer = params["HTTP_REFERER"] - r.UserAgent = params["HTTP_USER_AGENT"] if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { clen, err := strconv.Atoi64(lenstr) @@ -90,7 +81,7 @@ func RequestFromMap(params map[string]string) (*http.Request, os.Error) { // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers for k, v := range params { - if !strings.HasPrefix(k, "HTTP_") || skipHeader[k] { + if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" { continue } r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v) diff --git a/src/pkg/http/cgi/child_test.go b/src/pkg/http/cgi/child_test.go index d12947814..eee043bc9 100644 --- a/src/pkg/http/cgi/child_test.go +++ b/src/pkg/http/cgi/child_test.go @@ -28,23 +28,19 @@ func TestRequest(t *testing.T) { if err != nil { t.Fatalf("RequestFromMap: %v", err) } - if g, e := req.UserAgent, "goclient"; e != g { + if g, e := req.UserAgent(), "goclient"; e != g { t.Errorf("expected UserAgent %q; got %q", e, g) } if g, e := req.Method, "GET"; e != g { t.Errorf("expected Method %q; got %q", e, g) } - if g, e := req.Header.Get("User-Agent"), ""; e != g { - // Tests that we don't put recognized headers in the map - t.Errorf("expected User-Agent %q; got %q", e, g) - } if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { t.Errorf("expected Content-Type %q; got %q", e, g) } if g, e := req.ContentLength, int64(123); e != g { t.Errorf("expected ContentLength %d; got %d", e, g) } - if g, e := req.Referer, "elsewhere"; e != g { + if g, e := req.Referer(), "elsewhere"; e != g { t.Errorf("expected Referer %q; got %q", e, g) } if req.Header == nil { diff --git a/src/pkg/http/cgi/host.go b/src/pkg/http/cgi/host.go index 7ab3f9247..059fc758e 100644 --- a/src/pkg/http/cgi/host.go +++ b/src/pkg/http/cgi/host.go @@ -16,7 +16,6 @@ package cgi import ( "bufio" - "bytes" "exec" "fmt" "http" @@ -47,6 +46,12 @@ type Handler struct { Path string // path to the CGI executable Root string // root URI prefix of handler or empty for "/" + // Dir specifies the CGI executable's working directory. + // If Dir is empty, the base directory of Path is used. + // If Path has no base directory, the current working + // directory is used. + Dir string + Env []string // extra environment variables to set, if any, as "key=value" InheritEnv []string // environment variables to inherit from host, as "key" Logger *log.Logger // optional log for errors or nil to use log.Print @@ -106,20 +111,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, "HTTPS=on") } - if len(req.Cookie) > 0 { - b := new(bytes.Buffer) - for idx, c := range req.Cookie { - if idx > 0 { - b.Write([]byte("; ")) - } - fmt.Fprintf(b, "%s=%s", c.Name, c.Value) - } - env = append(env, "HTTP_COOKIE="+b.String()) - } - for k, v := range req.Header { k = strings.Map(upperCaseAndUnderscore, k) - env = append(env, "HTTP_"+k+"="+strings.Join(v, ", ")) + joinStr := ", " + if k == "COOKIE" { + joinStr = "; " + } + env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr)) } if req.ContentLength > 0 { @@ -133,11 +131,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, h.Env...) } - path := os.Getenv("PATH") - if path == "" { - path = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" + envPath := os.Getenv("PATH") + if envPath == "" { + envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" } - env = append(env, "PATH="+path) + env = append(env, "PATH="+envPath) for _, e := range h.InheritEnv { if v := os.Getenv(e); v != "" { @@ -151,7 +149,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } - cwd, pathBase := filepath.Split(h.Path) + var cwd, path string + if h.Dir != "" { + path = h.Path + cwd = h.Dir + } else { + cwd, path = filepath.Split(h.Path) + } if cwd == "" { cwd = "." } @@ -162,7 +166,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } cmd := &exec.Cmd{ - Path: pathBase, + Path: path, Args: append([]string{h.Path}, h.Args...), Dir: cwd, Env: env, @@ -205,7 +209,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if len(line) == 0 { break } - parts := strings.Split(string(line), ":", 2) + parts := strings.SplitN(string(line), ":", 2) if len(parts) < 2 { h.printf("cgi: bogus header line: %s", string(line)) continue diff --git a/src/pkg/http/cgi/host_test.go b/src/pkg/http/cgi/host_test.go index bbdb715cf..b08d8bbf6 100644 --- a/src/pkg/http/cgi/host_test.go +++ b/src/pkg/http/cgi/host_test.go @@ -13,8 +13,10 @@ import ( "http" "http/httptest" "os" + "path/filepath" "strings" "testing" + "runtime" ) func newRequest(httpreq string) *http.Request { @@ -46,7 +48,7 @@ readlines: } linesRead++ trimmedLine := strings.TrimRight(line, "\r\n") - split := strings.Split(trimmedLine, "=", 2) + split := strings.SplitN(trimmedLine, "=", 2) if len(split) != 2 { t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", len(split), linesRead, line, m) @@ -301,3 +303,77 @@ func TestInternalRedirect(t *testing.T) { } runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) } + +func TestDirUnix(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + cwd, _ := os.Getwd() + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + Dir: cwd, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + cwd, _ = os.Getwd() + cwd = filepath.Join(cwd, "testdata") + h = &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestDirWindows(t *testing.T) { + if runtime.GOOS != "windows" { + return + } + + cgifile, _ := filepath.Abs("testdata/test.cgi") + + var perl string + var err os.Error + perl, err = exec.LookPath("perl") + if err != nil { + return + } + perl, _ = filepath.Abs(perl) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + // If not specify Dir on windows, working directory should be + // base directory of perl. + cwd, _ = filepath.Split(perl) + if cwd != "" && cwd[len(cwd)-1] == filepath.Separator { + cwd = cwd[:len(cwd)-1] + } + h = &Handler{ + Path: perl, + Root: "/test.cgi", + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/src/pkg/http/cgi/testdata/test.cgi b/src/pkg/http/cgi/testdata/test.cgi index a1b2ff893..36c107f76 100755 --- a/src/pkg/http/cgi/testdata/test.cgi +++ b/src/pkg/http/cgi/testdata/test.cgi @@ -6,9 +6,9 @@ # Test script run as a child process under cgi_test.go use strict; -use CGI; +use Cwd; -my $q = CGI->new; +my $q = MiniCGI->new; my $params = $q->Vars; if ($params->{"loc"}) { @@ -39,3 +39,50 @@ foreach my $k (sort keys %ENV) { $clean_env =~ s/[\n\r]//g; print "env-$k=$clean_env\n"; } + +# NOTE: don't call getcwd() for windows. +# msys return /c/go/src/... not C:\go\... +my $dir; +if ($^O eq 'MSWin32' || $^O eq 'msys') { + my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; + $cmd =~ s!\\!/!g; + $dir = `$cmd /c cd`; + chomp $dir; +} else { + $dir = getcwd(); +} +print "cwd=$dir\n"; + + +# A minimal version of CGI.pm, for people without the perl-modules +# package installed. (CGI.pm used to be part of the Perl core, but +# some distros now bundle perl-base and perl-modules separately...) +package MiniCGI; + +sub new { + my $class = shift; + return bless {}, $class; +} + +sub Vars { + my $self = shift; + my $pairs; + if ($ENV{CONTENT_LENGTH}) { + $pairs = do { local $/; <STDIN> }; + } else { + $pairs = $ENV{QUERY_STRING}; + } + my $vars = {}; + foreach my $kv (split(/&/, $pairs)) { + my ($k, $v) = split(/=/, $kv, 2); + $vars->{_urldecode($k)} = _urldecode($v); + } + return $vars; +} + +sub _urldecode { + my $v = shift; + $v =~ tr/+/ /; + $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg; + return $v; +} diff --git a/src/pkg/http/chunked.go b/src/pkg/http/chunked.go index 59121c5a2..6c23e691f 100644 --- a/src/pkg/http/chunked.go +++ b/src/pkg/http/chunked.go @@ -9,6 +9,7 @@ import ( "log" "os" "strconv" + "bufio" ) // NewChunkedWriter returns a new writer that translates writes into HTTP @@ -64,3 +65,13 @@ func (cw *chunkedWriter) Close() os.Error { _, err := io.WriteString(cw.Wire, "0\r\n") return err } + +// NewChunkedReader returns a new reader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The reader returns os.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r *bufio.Reader) io.Reader { + return &chunkedReader{r: r} +} diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index 71b037042..4f63b44f2 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -16,6 +16,11 @@ import ( // A Client is an HTTP client. Its zero value (DefaultClient) is a usable client // that uses DefaultTransport. +// +// The Client's Transport typically has internal state (cached +// TCP connections), so Clients should be reused instead of created as +// needed. Clients are safe for concurrent use by multiple goroutines. +// // Client is not yet very configurable. type Client struct { Transport RoundTripper // if nil, DefaultTransport is used @@ -36,6 +41,9 @@ var DefaultClient = &Client{} // RoundTripper is an interface representing the ability to execute a // single HTTP transaction, obtaining the Response for a given Request. +// +// A RoundTripper must be safe for concurrent use by multiple +// goroutines. type RoundTripper interface { // RoundTrip executes a single HTTP transaction, returning // the Response for the request req. RoundTrip should not @@ -173,7 +181,7 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err os.Error) // Add the Referer header. lastReq := via[len(via)-1] if lastReq.URL.Scheme != "https" { - req.Referer = lastReq.URL.String() + req.Header.Set("Referer", lastReq.URL.String()) } err = redirectChecker(req, via) @@ -190,7 +198,7 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err os.Error) if shouldRedirect(r.StatusCode) { r.Body.Close() if url = r.Header.Get("Location"); url == "" { - err = os.ErrorString(fmt.Sprintf("%d response missing Location header", r.StatusCode)) + err = os.NewError(fmt.Sprintf("%d response missing Location header", r.StatusCode)) break } base = req.URL @@ -207,7 +215,7 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err os.Error) func defaultCheckRedirect(req *Request, via []*Request) os.Error { if len(via) >= 10 { - return os.ErrorString("stopped after 10 redirects") + return os.NewError("stopped after 10 redirects") } return nil } diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index 9ef81d9d4..3b8558535 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -12,6 +12,7 @@ import ( "http/httptest" "io" "io/ioutil" + "net" "os" "strconv" "strings" @@ -149,7 +150,7 @@ func TestRedirects(t *testing.T) { n, _ := strconv.Atoi(r.FormValue("n")) // Test Referer header. (7 is arbitrary position to test at) if n == 7 { - if g, e := r.Referer, ts.URL+"/?n=6"; e != g { + if g, e := r.Referer(), ts.URL+"/?n=6"; e != g { t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) } } @@ -243,3 +244,48 @@ func TestStreamingGet(t *testing.T) { t.Fatalf("at end expected EOF, got %v", err) } } + +type writeCountingConn struct { + net.Conn + count *int +} + +func (c *writeCountingConn) Write(p []byte) (int, os.Error) { + *c.count++ + return c.Conn.Write(p) +} + +// TestClientWrites verifies that client requests are buffered and we +// don't send a TCP packet per line of the http request + body. +func TestClientWrites(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + writes := 0 + dialer := func(netz string, addr string) (net.Conn, os.Error) { + c, err := net.Dial(netz, addr) + if err == nil { + c = &writeCountingConn{c, &writes} + } + return c, err + } + c := &Client{Transport: &Transport{Dial: dialer}} + + _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Get request did %d Write calls, want 1", writes) + } + + writes = 0 + _, err = c.PostForm(ts.URL, Values{"foo": {"bar"}}) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Post request did %d Write calls, want 1", writes) + } +} diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go index eb61a7001..fe70431bb 100644 --- a/src/pkg/http/cookie.go +++ b/src/pkg/http/cookie.go @@ -7,9 +7,6 @@ package http import ( "bytes" "fmt" - "io" - "os" - "sort" "strconv" "strings" "time" @@ -40,30 +37,25 @@ type Cookie struct { } // readSetCookies parses all "Set-Cookie" values from -// the header h, removes the successfully parsed values from the -// "Set-Cookie" key in h and returns the parsed Cookies. +// the header h and returns the successfully parsed Cookies. func readSetCookies(h Header) []*Cookie { cookies := []*Cookie{} - var unparsedLines []string for _, line := range h["Set-Cookie"] { - parts := strings.Split(strings.TrimSpace(line), ";", -1) + parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 1 && parts[0] == "" { continue } parts[0] = strings.TrimSpace(parts[0]) j := strings.Index(parts[0], "=") if j < 0 { - unparsedLines = append(unparsedLines, line) continue } name, value := parts[0][:j], parts[0][j+1:] if !isCookieNameValid(name) { - unparsedLines = append(unparsedLines, line) continue } value, success := parseCookieValue(value) if !success { - unparsedLines = append(unparsedLines, line) continue } c := &Cookie{ @@ -134,77 +126,56 @@ func readSetCookies(h Header) []*Cookie { } cookies = append(cookies, c) } - h["Set-Cookie"] = unparsedLines, unparsedLines != nil return cookies } // SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. func SetCookie(w ResponseWriter, cookie *Cookie) { - var b bytes.Buffer - writeSetCookieToBuffer(&b, cookie) - w.Header().Add("Set-Cookie", b.String()) + w.Header().Add("Set-Cookie", cookie.String()) } -func writeSetCookieToBuffer(buf *bytes.Buffer, c *Cookie) { - fmt.Fprintf(buf, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) +// String returns the serialization of the cookie for use in a Cookie +// header (if only Name and Value are set) or a Set-Cookie response +// header (if other fields are set). +func (c *Cookie) String() string { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) if len(c.Path) > 0 { - fmt.Fprintf(buf, "; Path=%s", sanitizeValue(c.Path)) + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) } if len(c.Domain) > 0 { - fmt.Fprintf(buf, "; Domain=%s", sanitizeValue(c.Domain)) + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) } if len(c.Expires.Zone) > 0 { - fmt.Fprintf(buf, "; Expires=%s", c.Expires.Format(time.RFC1123)) + fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) } if c.MaxAge > 0 { - fmt.Fprintf(buf, "; Max-Age=%d", c.MaxAge) + fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) } else if c.MaxAge < 0 { - fmt.Fprintf(buf, "; Max-Age=0") + fmt.Fprintf(&b, "; Max-Age=0") } if c.HttpOnly { - fmt.Fprintf(buf, "; HttpOnly") + fmt.Fprintf(&b, "; HttpOnly") } if c.Secure { - fmt.Fprintf(buf, "; Secure") + fmt.Fprintf(&b, "; Secure") } + return b.String() } -// writeSetCookies writes the wire representation of the set-cookies -// to w. Each cookie is written on a separate "Set-Cookie: " line. -// This choice is made because HTTP parsers tend to have a limit on -// line-length, so it seems safer to place cookies on separate lines. -func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { - if kk == nil { - return nil - } - lines := make([]string, 0, len(kk)) - var b bytes.Buffer - for _, c := range kk { - b.Reset() - writeSetCookieToBuffer(&b, c) - lines = append(lines, "Set-Cookie: "+b.String()+"\r\n") - } - sort.SortStrings(lines) - for _, l := range lines { - if _, err := io.WriteString(w, l); err != nil { - return err - } - } - return nil -} - -// readCookies parses all "Cookie" values from -// the header h, removes the successfully parsed values from the -// "Cookie" key in h and returns the parsed Cookies. -func readCookies(h Header) []*Cookie { +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned +func readCookies(h Header, filter string) []*Cookie { cookies := []*Cookie{} lines, ok := h["Cookie"] if !ok { return cookies } - unparsedLines := []string{} + for _, line := range lines { - parts := strings.Split(strings.TrimSpace(line), ";", -1) + parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 1 && parts[0] == "" { continue } @@ -215,50 +186,27 @@ func readCookies(h Header) []*Cookie { if len(parts[i]) == 0 { continue } - attr, val := parts[i], "" - if j := strings.Index(attr, "="); j >= 0 { - attr, val = attr[:j], attr[j+1:] + name, val := parts[i], "" + if j := strings.Index(name, "="); j >= 0 { + name, val = name[:j], name[j+1:] } - if !isCookieNameValid(attr) { + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { continue } val, success := parseCookieValue(val) if !success { continue } - cookies = append(cookies, &Cookie{Name: attr, Value: val}) + cookies = append(cookies, &Cookie{Name: name, Value: val}) parsedPairs++ } - if parsedPairs == 0 { - unparsedLines = append(unparsedLines, line) - } } - h["Cookie"] = unparsedLines, len(unparsedLines) > 0 return cookies } -// writeCookies writes the wire representation of the cookies to -// w. According to RFC 6265 section 5.4, writeCookies does not -// attach more than one Cookie header field. That means all -// cookies, if any, are written into the same line, separated by -// semicolon. -func writeCookies(w io.Writer, kk []*Cookie) os.Error { - if len(kk) == 0 { - return nil - } - var buf bytes.Buffer - fmt.Fprintf(&buf, "Cookie: ") - for i, c := range kk { - if i > 0 { - fmt.Fprintf(&buf, "; ") - } - fmt.Fprintf(&buf, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) - } - fmt.Fprintf(&buf, "\r\n") - _, err := w.Write(buf.Bytes()) - return err -} - func sanitizeName(n string) string { n = strings.Replace(n, "\n", "-", -1) n = strings.Replace(n, "\r", "-", -1) diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go index 02e42226b..d7aeda0be 100644 --- a/src/pkg/http/cookie_test.go +++ b/src/pkg/http/cookie_test.go @@ -5,7 +5,6 @@ package http import ( - "bytes" "fmt" "json" "os" @@ -15,30 +14,31 @@ import ( ) var writeSetCookiesTests = []struct { - Cookies []*Cookie - Raw string + Cookie *Cookie + Raw string }{ { - []*Cookie{ - &Cookie{Name: "cookie-1", Value: "v$1"}, - &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, - &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, - &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, - }, - "Set-Cookie: cookie-1=v$1\r\n" + - "Set-Cookie: cookie-2=two; Max-Age=3600\r\n" + - "Set-Cookie: cookie-3=three; Domain=.example.com\r\n" + - "Set-Cookie: cookie-4=four; Path=/restricted/\r\n", + &Cookie{Name: "cookie-1", Value: "v$1"}, + "cookie-1=v$1", + }, + { + &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + "cookie-2=two; Max-Age=3600", + }, + { + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + "cookie-3=three; Domain=.example.com", + }, + { + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, + "cookie-4=four; Path=/restricted/", }, } func TestWriteSetCookies(t *testing.T) { for i, tt := range writeSetCookiesTests { - var w bytes.Buffer - writeSetCookies(&w, tt.Cookies) - seen := string(w.Bytes()) - if seen != tt.Raw { - t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, seen) + if g, e := tt.Cookie.String(), tt.Raw; g != e { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g) continue } } @@ -73,7 +73,7 @@ func TestSetCookie(t *testing.T) { } } -var writeCookiesTests = []struct { +var addCookieTests = []struct { Cookies []*Cookie Raw string }{ @@ -83,7 +83,7 @@ var writeCookiesTests = []struct { }, { []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1"}}, - "Cookie: cookie-1=v$1\r\n", + "cookie-1=v$1", }, { []*Cookie{ @@ -91,17 +91,18 @@ var writeCookiesTests = []struct { &Cookie{Name: "cookie-2", Value: "v$2"}, &Cookie{Name: "cookie-3", Value: "v$3"}, }, - "Cookie: cookie-1=v$1; cookie-2=v$2; cookie-3=v$3\r\n", + "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3", }, } -func TestWriteCookies(t *testing.T) { - for i, tt := range writeCookiesTests { - var w bytes.Buffer - writeCookies(&w, tt.Cookies) - seen := string(w.Bytes()) - if seen != tt.Raw { - t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, seen) +func TestAddCookie(t *testing.T) { + for i, tt := range addCookieTests { + req, _ := NewRequest("GET", "http://example.com/", nil) + for _, c := range tt.Cookies { + req.AddCookie(c) + } + if g := req.Header.Get("Cookie"); g != tt.Raw { + t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g) continue } } @@ -140,30 +141,61 @@ func toJSON(v interface{}) string { func TestReadSetCookies(t *testing.T) { for i, tt := range readSetCookiesTests { - c := readSetCookies(tt.Header) - if !reflect.DeepEqual(c, tt.Cookies) { - t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) - continue + for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input + c := readSetCookies(tt.Header) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } } } } var readCookiesTests = []struct { Header Header + Filter string Cookies []*Cookie }{ { - Header{"Cookie": {"Cookie-1=v$1"}}, - []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1"}}, + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "", + []*Cookie{ + &Cookie{Name: "Cookie-1", Value: "v$1"}, + &Cookie{Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "c2", + []*Cookie{ + &Cookie{Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "", + []*Cookie{ + &Cookie{Name: "Cookie-1", Value: "v$1"}, + &Cookie{Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "c2", + []*Cookie{ + &Cookie{Name: "c2", Value: "v2"}, + }, }, } func TestReadCookies(t *testing.T) { for i, tt := range readCookiesTests { - c := readCookies(tt.Header) - if !reflect.DeepEqual(c, tt.Cookies) { - t.Errorf("#%d readCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) - continue + for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input + c := readCookies(tt.Header, tt.Filter) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } } } } diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go index 28a0c51ef..0b830053a 100644 --- a/src/pkg/http/fs.go +++ b/src/pkg/http/fs.go @@ -11,6 +11,7 @@ import ( "io" "mime" "os" + "path" "path/filepath" "strconv" "strings" @@ -18,6 +19,38 @@ import ( "utf8" ) +// A Dir implements http.FileSystem using the native file +// system restricted to a specific directory tree. +type Dir string + +func (d Dir) Open(name string) (File, os.Error) { + if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 { + return nil, os.NewError("http: invalid character in file path") + } + f, err := os.Open(filepath.Join(string(d), filepath.FromSlash(path.Clean("/"+name)))) + if err != nil { + return nil, err + } + return f, nil +} + +// A FileSystem implements access to a collection of named files. +// The elements in a file path are separated by slash ('/', U+002F) +// characters, regardless of host operating system convention. +type FileSystem interface { + Open(name string) (File, os.Error) +} + +// A File is returned by a FileSystem's Open method and can be +// served by the FileServer implementation. +type File interface { + Close() os.Error + Stat() (*os.FileInfo, os.Error) + Readdir(count int) ([]os.FileInfo, os.Error) + Read([]byte) (int, os.Error) + Seek(offset int64, whence int) (int64, os.Error) +} + // Heuristic: b is text if it is valid UTF-8 and doesn't // contain any unprintable ASCII or Unicode characters. func isText(b []byte) bool { @@ -44,7 +77,7 @@ func isText(b []byte) bool { return true } -func dirList(w ResponseWriter, f *os.File) { +func dirList(w ResponseWriter, f File) { fmt.Fprintf(w, "<pre>\n") for { dirs, err := f.Readdir(100) @@ -63,7 +96,8 @@ func dirList(w ResponseWriter, f *os.File) { fmt.Fprintf(w, "</pre>\n") } -func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { +// name is '/'-separated, not filepath.Separator. +func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { const indexPage = "/index.html" // redirect .../index.html to .../ @@ -72,7 +106,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { return } - f, err := os.Open(name) + f, err := fs.Open(name) if err != nil { // TODO expose actual error? NotFound(w, r) @@ -113,7 +147,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { // use contents of index.html for directory, if present if d.IsDirectory() { index := name + filepath.FromSlash(indexPage) - ff, err := os.Open(index) + ff, err := fs.Open(index) if err == nil { defer ff.Close() dd, err := ff.Stat() @@ -157,7 +191,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { // TODO(adg): handle multiple ranges ranges, err := parseRange(r.Header.Get("Range"), size) if err == nil && len(ranges) > 1 { - err = os.ErrorString("multiple ranges not supported") + err = os.NewError("multiple ranges not supported") } if err != nil { Error(w, err.String(), StatusRequestedRangeNotSatisfiable) @@ -188,28 +222,26 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { // ServeFile replies to the request with the contents of the named file or directory. func ServeFile(w ResponseWriter, r *Request, name string) { - serveFile(w, r, name, false) + serveFile(w, r, Dir(name), "", false) } type fileHandler struct { - root string - prefix string + root FileSystem } // FileServer returns a handler that serves HTTP requests // with the contents of the file system rooted at root. -// It strips prefix from the incoming requests before -// looking up the file name in the file system. -func FileServer(root, prefix string) Handler { return &fileHandler{root, prefix} } +// +// To use the operating system's file system implementation, +// use http.Dir: +// +// http.Handle("/", http.FileServer(http.Dir("/tmp"))) +func FileServer(root FileSystem) Handler { + return &fileHandler{root} +} func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { - path := r.URL.Path - if !strings.HasPrefix(path, f.prefix) { - NotFound(w, r) - return - } - path = path[len(f.prefix):] - serveFile(w, r, filepath.Join(f.root, filepath.FromSlash(path)), true) + serveFile(w, r, f.root, path.Clean(r.URL.Path), true) } // httpRange specifies the byte range to be sent to the client. @@ -227,7 +259,7 @@ func parseRange(s string, size int64) ([]httpRange, os.Error) { return nil, os.NewError("invalid range") } var ranges []httpRange - for _, ra := range strings.Split(s[len(b):], ",", -1) { + for _, ra := range strings.Split(s[len(b):], ",") { i := strings.Index(ra, "-") if i < 0 { return nil, os.NewError("invalid range") diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go index 554053449..dbbdf05bd 100644 --- a/src/pkg/http/fs_test.go +++ b/src/pkg/http/fs_test.go @@ -85,6 +85,72 @@ func TestServeFile(t *testing.T) { } } +type testFileSystem struct { + open func(name string) (File, os.Error) +} + +func (fs *testFileSystem) Open(name string) (File, os.Error) { + return fs.open(name) +} + +func TestFileServerCleans(t *testing.T) { + ch := make(chan string, 1) + fs := FileServer(&testFileSystem{func(name string) (File, os.Error) { + ch <- name + return nil, os.ENOENT + }}) + tests := []struct { + reqPath, openArg string + }{ + {"/foo.txt", "/foo.txt"}, + {"//foo.txt", "/foo.txt"}, + {"/../foo.txt", "/foo.txt"}, + } + req, _ := NewRequest("GET", "http://example.com", nil) + for n, test := range tests { + rec := httptest.NewRecorder() + req.URL.Path = test.reqPath + fs.ServeHTTP(rec, req) + if got := <-ch; got != test.openArg { + t.Errorf("test %d: got %q, want %q", n, got, test.openArg) + } + } +} + +func TestDirJoin(t *testing.T) { + wfi, err := os.Stat("/etc/hosts") + if err != nil { + t.Logf("skipping test; no /etc/hosts file") + return + } + test := func(d Dir, name string) { + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + gfi, err := f.Stat() + if err != nil { + t.Fatalf("stat of %s: %v", err) + } + if gfi.Ino != wfi.Ino { + t.Errorf("%s got different inode") + } + } + test(Dir("/etc/"), "/hosts") + test(Dir("/etc/"), "hosts") + test(Dir("/etc/"), "../../../../hosts") + test(Dir("/etc"), "/hosts") + test(Dir("/etc"), "hosts") + test(Dir("/etc"), "../../../../hosts") + + // Not really directories, but since we use this trick in + // ServeFile, test it: + test(Dir("/etc/hosts"), "") + test(Dir("/etc/hosts"), "/") + test(Dir("/etc/hosts"), "../") +} + func TestServeFileContentType(t *testing.T) { const ctype = "icecream/chocolate" override := false diff --git a/src/pkg/http/header.go b/src/pkg/http/header.go index 95140b01f..08b077130 100644 --- a/src/pkg/http/header.go +++ b/src/pkg/http/header.go @@ -56,15 +56,12 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) os.Error { keys = append(keys, k) } } - sort.SortStrings(keys) + sort.Strings(keys) for _, k := range keys { for _, v := range h[k] { v = strings.Replace(v, "\n", " ", -1) v = strings.Replace(v, "\r", " ", -1) v = strings.TrimSpace(v) - if v == "" { - continue - } if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { return err } diff --git a/src/pkg/http/header_test.go b/src/pkg/http/header_test.go index 7e24cb069..ccdee8a97 100644 --- a/src/pkg/http/header_test.go +++ b/src/pkg/http/header_test.go @@ -57,6 +57,16 @@ var headerWriteTests = []struct { map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, "", }, + { + Header{ + "Nil": nil, + "Empty": {}, + "Blank": {""}, + "Double-Blank": {"", ""}, + }, + nil, + "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", + }, } func TestHeaderWrite(t *testing.T) { diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go index 62f9ff1b5..78bf9058f 100644 --- a/src/pkg/http/persist.go +++ b/src/pkg/http/persist.go @@ -24,6 +24,9 @@ var ( // to regain control over the connection. ServerConn supports pipe-lining, // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. +// +// ServerConn is low-level and should not be needed by most applications. +// See Server. type ServerConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -211,6 +214,9 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { // connection, while respecting the HTTP keepalive logic. ClientConn // supports hijacking the connection calling Hijack to // regain control of the underlying net.Conn and deal with it as desired. +// +// ClientConn is low-level and should not be needed by most applications. +// See Client. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn diff --git a/src/pkg/http/readrequest_test.go b/src/pkg/http/readrequest_test.go index d93e573f5..79f8de70d 100644 --- a/src/pkg/http/readrequest_test.go +++ b/src/pkg/http/readrequest_test.go @@ -13,11 +13,15 @@ import ( ) type reqTest struct { - Raw string - Req Request - Body string + Raw string + Req *Request + Body string + Error string } +var noError = "" +var noBody = "" + var reqTests = []reqTest{ // Baseline test; All Request fields included for template use { @@ -33,7 +37,7 @@ var reqTests = []reqTest{ "Proxy-Connection: keep-alive\r\n\r\n" + "abcdef\n???", - Request{ + &Request{ Method: "GET", RawURL: "http://www.techcrunch.com/", URL: &URL{ @@ -58,16 +62,43 @@ var reqTests = []reqTest{ "Keep-Alive": {"300"}, "Proxy-Connection": {"keep-alive"}, "Content-Length": {"7"}, + "User-Agent": {"Fake"}, }, Close: false, ContentLength: 7, Host: "www.techcrunch.com", - Referer: "", - UserAgent: "Fake", Form: Values{}, }, "abcdef\n", + + noError, + }, + + // GET request with no body (the normal case) + { + "GET / HTTP/1.1\r\n" + + "Host: foo.com\r\n\r\n", + + &Request{ + Method: "GET", + RawURL: "/", + URL: &URL{ + Raw: "/", + Path: "/", + RawPath: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Close: false, + ContentLength: 0, + Host: "foo.com", + Form: Values{}, + }, + + noBody, + noError, }, // Tests that we don't parse a path that looks like a @@ -76,7 +107,7 @@ var reqTests = []reqTest{ "GET //user@host/is/actually/a/path/ HTTP/1.1\r\n" + "Host: test\r\n\r\n", - Request{ + &Request{ Method: "GET", RawURL: "//user@host/is/actually/a/path/", URL: &URL{ @@ -95,14 +126,31 @@ var reqTests = []reqTest{ ProtoMinor: 1, Header: Header{}, Close: false, - ContentLength: -1, + ContentLength: 0, Host: "test", - Referer: "", - UserAgent: "", Form: Values{}, }, - "", + noBody, + noError, + }, + + // Tests a bogus abs_path on the Request-Line (RFC 2616 section 5.1.2) + { + "GET ../../../../etc/passwd HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + nil, + noBody, + "parse ../../../../etc/passwd: invalid URI for request", + }, + + // Tests missing URL: + { + "GET HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + nil, + noBody, + "parse : empty url", }, } @@ -113,12 +161,14 @@ func TestReadRequest(t *testing.T) { braw.WriteString(tt.Raw) req, err := ReadRequest(bufio.NewReader(&braw)) if err != nil { - t.Errorf("#%d: %s", i, err) + if err.String() != tt.Error { + t.Errorf("#%d: error %q, want error %q", i, err.String(), tt.Error) + } continue } rbody := req.Body req.Body = nil - diff(t, fmt.Sprintf("#%d Request", i), req, &tt.Req) + diff(t, fmt.Sprintf("#%d Request", i), req, tt.Req) var bout bytes.Buffer if rbody != nil { io.Copy(&bout, rbody) diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index bdc3a7e4f..2917cc1e6 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -35,13 +35,15 @@ const ( // ErrMissingFile is returned by FormFile when the provided file field name // is either not present in the request or not a file field. -var ErrMissingFile = os.ErrorString("http: no such file") +var ErrMissingFile = os.NewError("http: no such file") // HTTP request parsing errors. type ProtocolError struct { - os.ErrorString + ErrorString string } +func (err *ProtocolError) String() string { return err.ErrorString } + var ( ErrLineTooLong = &ProtocolError{"header line too long"} ErrHeaderTooLong = &ProtocolError{"header too long"} @@ -60,10 +62,10 @@ type badStringError struct { func (e *badStringError) String() string { return fmt.Sprintf("%s %q", e.what, e.str) } -var reqExcludeHeader = map[string]bool{ +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ "Host": true, "User-Agent": true, - "Referer": true, "Content-Length": true, "Transfer-Encoding": true, "Trailer": true, @@ -102,9 +104,6 @@ type Request struct { // following a hyphen uppercase and the rest lowercase. Header Header - // Cookie records the HTTP cookies sent with the request. - Cookie []*Cookie - // The message body. Body io.ReadCloser @@ -125,21 +124,6 @@ type Request struct { // or the host name given in the URL itself. Host string - // The referring URL, if sent in the request. - // - // Referer is misspelled as in the request itself, - // a mistake from the earliest days of HTTP. - // This value can also be fetched from the Header map - // as Header["Referer"]; the benefit of making it - // available as a structure field is that the compiler - // can diagnose programs that use the alternate - // (correct English) spelling req.Referrer but cannot - // diagnose programs that use Header["Referrer"]. - Referer string - - // The User-Agent: header string, if sent in the request. - UserAgent string - // The parsed form. Only available after ParseForm is called. Form Values @@ -176,6 +160,52 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// UserAgent returns the client's User-Agent, if sent in the request. +func (r *Request) UserAgent() string { + return r.Header.Get("User-Agent") +} + +// Cookies parses and returns the HTTP cookies sent with the request. +func (r *Request) Cookies() []*Cookie { + return readCookies(r.Header, "") +} + +var ErrNoCookie = os.NewError("http: named cookied not present") + +// Cookie returns the named cookie provided in the request or +// ErrNoCookie if not found. +func (r *Request) Cookie(name string) (*Cookie, os.Error) { + for _, c := range readCookies(r.Header, name) { + return c, nil + } + return nil, ErrNoCookie +} + +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. That +// means all cookies, if any, are written into the same line, +// separated by semicolon. +func (r *Request) AddCookie(c *Cookie) { + s := fmt.Sprintf("%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) + if c := r.Header.Get("Cookie"); c != "" { + r.Header.Set("Cookie", c+"; "+s) + } else { + r.Header.Set("Cookie", s) + } +} + +// Referer returns the referring URL, if sent in the request. +// +// Referer is misspelled as in the request itself, a mistake from the +// earliest days of HTTP. This value can also be fetched from the +// Header map as Header["Referer"]; the benefit of making it available +// as a method is that the compiler can diagnose programs that use the +// alternate (correct English) spelling req.Referrer() but cannot +// diagnose programs that use Header["Referrer"]. +func (r *Request) Referer() string { + return r.Header.Get("Referer") +} + // multipartByReader is a sentinel value. // Its presence in Request.MultipartForm indicates that parsing of the request // body has been handed off to a MultipartReader instead of ParseMultipartFrom. @@ -188,7 +218,7 @@ var multipartByReader = &multipart.Form{ // multipart/form-data POST request, else returns nil and an error. // Use this function instead of ParseMultipartForm to // process the request body as a stream. -func (r *Request) MultipartReader() (multipart.Reader, os.Error) { +func (r *Request) MultipartReader() (*multipart.Reader, os.Error) { if r.MultipartForm == multipartByReader { return nil, os.NewError("http: MultipartReader called twice") } @@ -199,7 +229,7 @@ func (r *Request) MultipartReader() (multipart.Reader, os.Error) { return r.multipartReader() } -func (r *Request) multipartReader() (multipart.Reader, os.Error) { +func (r *Request) multipartReader() (*multipart.Reader, os.Error) { v := r.Header.Get("Content-Type") if v == "" { return nil, ErrNotMultipart @@ -230,10 +260,7 @@ const defaultUserAgent = "Go http package" // Host // RawURL, if non-empty, or else URL // Method (defaults to "GET") -// UserAgent (defaults to defaultUserAgent) -// Referer -// Header (only keys not already in this list) -// Cookie +// Header // ContentLength // TransferEncoding // Body @@ -277,13 +304,22 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { } } - fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), uri) + bw := bufio.NewWriter(w) + fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), uri) // Header lines - fmt.Fprintf(w, "Host: %s\r\n", host) - fmt.Fprintf(w, "User-Agent: %s\r\n", valueOrDefault(req.UserAgent, defaultUserAgent)) - if req.Referer != "" { - fmt.Fprintf(w, "Referer: %s\r\n", req.Referer) + fmt.Fprintf(bw, "Host: %s\r\n", host) + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := defaultUserAgent + if req.Header != nil { + if ua := req.Header["User-Agent"]; len(ua) > 0 { + userAgent = ua[0] + } + } + if userAgent != "" { + fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent) } // Process Body,ContentLength,Close,Trailer @@ -291,35 +327,25 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { if err != nil { return err } - err = tw.WriteHeader(w) + err = tw.WriteHeader(bw) if err != nil { return err } // TODO: split long values? (If so, should share code with Conn.Write) - // TODO: if Header includes values for Host, User-Agent, or Referer, this - // may conflict with the User-Agent or Referer headers we add manually. - // One solution would be to remove the Host, UserAgent, and Referer fields - // from Request, and introduce Request methods along the lines of - // Response.{GetHeader,AddHeader} and string constants for "Host", - // "User-Agent" and "Referer". - err = req.Header.WriteSubset(w, reqExcludeHeader) + err = req.Header.WriteSubset(bw, reqWriteExcludeHeader) if err != nil { return err } - if err = writeCookies(w, req.Cookie); err != nil { - return err - } - - io.WriteString(w, "\r\n") + io.WriteString(bw, "\r\n") // Write body and trailer - err = tw.WriteBody(w) + err = tw.WriteBody(bw) if err != nil { return err } - + bw.Flush() return nil } @@ -402,10 +428,6 @@ type chunkedReader struct { err os.Error } -func newChunkedReader(r *bufio.Reader) *chunkedReader { - return &chunkedReader{r: r} -} - func (cr *chunkedReader) beginChunk() { // chunk-size CRLF var line string @@ -485,13 +507,6 @@ func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { req.ContentLength = int64(v.Len()) case *bytes.Buffer: req.ContentLength = int64(v.Len()) - default: - req.ContentLength = -1 // chunked - } - if req.ContentLength == 0 { - // To prevent chunking and disambiguate this - // from the default ContentLength zero value. - req.TransferEncoding = []string{"identity"} } } @@ -524,7 +539,7 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { } var f []string - if f = strings.Split(s, " ", 3); len(f) < 3 { + if f = strings.SplitN(s, " ", 3); len(f) < 3 { return nil, &badStringError{"malformed HTTP request", s} } req.Method, req.RawURL, req.Proto = f[0], f[1], f[2] @@ -559,13 +574,6 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { fixPragmaCacheControl(req.Header) - // Pull out useful fields as a convenience to clients. - req.Referer = req.Header.Get("Referer") - req.Header.Del("Referer") - - req.UserAgent = req.Header.Get("User-Agent") - req.Header.Del("User-Agent") - // TODO: Parse specific header values: // Accept // Accept-Encoding @@ -597,8 +605,6 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { return nil, err } - req.Cookie = readCookies(req.Header) - return req, nil } @@ -652,11 +658,11 @@ func ParseQuery(query string) (m Values, err os.Error) { } func parseQuery(m Values, query string) (err os.Error) { - for _, kv := range strings.Split(query, "&", -1) { + for _, kv := range strings.Split(query, "&") { if len(kv) == 0 { continue } - kvPair := strings.Split(kv, "=", 2) + kvPair := strings.SplitN(kv, "=", 2) var key, value string var e os.Error @@ -690,10 +696,10 @@ func (r *Request) ParseForm() (err os.Error) { } if r.Method == "POST" { if r.Body == nil { - return os.ErrorString("missing form body") + return os.NewError("missing form body") } ct := r.Header.Get("Content-Type") - switch strings.Split(ct, ";", 2)[0] { + switch strings.SplitN(ct, ";", 2)[0] { case "text/plain", "application/x-www-form-urlencoded", "": const maxFormSize = int64(10 << 20) // 10 MB is a lot of text. b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1)) diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index 98fbcf459..0052c0cfc 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -6,6 +6,7 @@ package http import ( "bytes" + "fmt" "io" "io/ioutil" "os" @@ -15,7 +16,7 @@ import ( type reqWriteTest struct { Req Request - Body []byte + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body Raw string RawProxy string } @@ -47,13 +48,12 @@ var reqWriteTests = []reqWriteTest{ "Accept-Language": {"en-us,en;q=0.5"}, "Keep-Alive": {"300"}, "Proxy-Connection": {"keep-alive"}, + "User-Agent": {"Fake"}, }, - Body: nil, - Close: false, - Host: "www.techcrunch.com", - Referer: "", - UserAgent: "Fake", - Form: map[string][]string{}, + Body: nil, + Close: false, + Host: "www.techcrunch.com", + Form: map[string][]string{}, }, nil, @@ -99,13 +99,13 @@ var reqWriteTests = []reqWriteTest{ "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + - "6\r\nabcdef\r\n0\r\n\r\n", + chunk("abcdef") + chunk(""), "GET http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + - "6\r\nabcdef\r\n0\r\n\r\n", + chunk("abcdef") + chunk(""), }, // HTTP/1.1 POST => chunked coding; body; empty trailer { @@ -130,14 +130,14 @@ var reqWriteTests = []reqWriteTest{ "User-Agent: Go http package\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + - "6\r\nabcdef\r\n0\r\n\r\n", + chunk("abcdef") + chunk(""), "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + - "6\r\nabcdef\r\n0\r\n\r\n", + chunk("abcdef") + chunk(""), }, // HTTP/1.1 POST with Content-Length, no chunking @@ -225,13 +225,75 @@ var reqWriteTests = []reqWriteTest{ "User-Agent: Go http package\r\n" + "\r\n", }, + + // Request with a 0 ContentLength and a 0 byte body. + { + Request{ + Method: "POST", + RawURL: "/", + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, + + "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "\r\n", + + "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "\r\n", + }, + + // Request with a 0 ContentLength and a 1 byte body. + { + Request{ + Method: "POST", + RawURL: "/", + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, + + "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("x") + chunk(""), + + "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("x") + chunk(""), + }, } func TestRequestWrite(t *testing.T) { for i := range reqWriteTests { tt := &reqWriteTests[i] + + setBody := func() { + switch b := tt.Body.(type) { + case []byte: + tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + case func() io.ReadCloser: + tt.Req.Body = b() + } + } if tt.Body != nil { - tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(tt.Body)) + setBody() + } + if tt.Req.Header == nil { + tt.Req.Header = make(Header) } var braw bytes.Buffer err := tt.Req.Write(&braw) @@ -246,7 +308,7 @@ func TestRequestWrite(t *testing.T) { } if tt.Body != nil { - tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(tt.Body)) + setBody() } var praw bytes.Buffer err = tt.Req.WriteProxy(&praw) @@ -278,41 +340,30 @@ func (rc *closeChecker) Close() os.Error { func TestRequestWriteClosesBody(t *testing.T) { rc := &closeChecker{Reader: strings.NewReader("my body")} req, _ := NewRequest("POST", "http://foo.com/", rc) - if g, e := req.ContentLength, int64(-1); g != e { - t.Errorf("got req.ContentLength %d, want %d", g, e) + if req.ContentLength != 0 { + t.Errorf("got req.ContentLength %d, want 0", req.ContentLength) } buf := new(bytes.Buffer) req.Write(buf) if !rc.closed { t.Error("body not closed after write") } - if g, e := buf.String(), "POST / HTTP/1.1\r\nHost: foo.com\r\nUser-Agent: Go http package\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nmy body\r\n0\r\n\r\n"; g != e { - t.Errorf("write:\n got: %s\nwant: %s", g, e) + expected := "POST / HTTP/1.1\r\n" + + "Host: foo.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + // TODO: currently we don't buffer before chunking, so we get a + // single "m" chunk before the other chunks, as this was the 1-byte + // read from our MultiReader where we stiched the Body back together + // after sniffing whether the Body was 0 bytes or not. + chunk("m") + + chunk("y body") + + chunk("") + if buf.String() != expected { + t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected) } } -func TestZeroLengthNewRequest(t *testing.T) { - var buf bytes.Buffer - - // Writing with default identity encoding - req, _ := NewRequest("PUT", "http://foo.com/", strings.NewReader("")) - if len(req.TransferEncoding) == 0 || req.TransferEncoding[0] != "identity" { - t.Fatalf("got req.TransferEncoding of %v, want %v", req.TransferEncoding, []string{"identity"}) - } - if g, e := req.ContentLength, int64(0); g != e { - t.Errorf("got req.ContentLength %d, want %d", g, e) - } - req.Write(&buf) - if g, e := buf.String(), "PUT / HTTP/1.1\r\nHost: foo.com\r\nUser-Agent: Go http package\r\nContent-Length: 0\r\n\r\n"; g != e { - t.Errorf("identity write:\n got: %s\nwant: %s", g, e) - } - - // Overriding identity encoding and forcing chunked. - req, _ = NewRequest("PUT", "http://foo.com/", strings.NewReader("")) - req.TransferEncoding = nil - buf.Reset() - req.Write(&buf) - if g, e := buf.String(), "PUT / HTTP/1.1\r\nHost: foo.com\r\nUser-Agent: Go http package\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; g != e { - t.Errorf("chunked write:\n got: %s\nwant: %s", g, e) - } +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) } diff --git a/src/pkg/http/response.go b/src/pkg/http/response.go index 42e60c1f6..915327a69 100644 --- a/src/pkg/http/response.go +++ b/src/pkg/http/response.go @@ -40,9 +40,6 @@ type Response struct { // Keys in the map are canonicalized (see CanonicalHeaderKey). Header Header - // SetCookie records the Set-Cookie requests sent with the response. - SetCookie []*Cookie - // Body represents the response body. Body io.ReadCloser @@ -71,6 +68,11 @@ type Response struct { Request *Request } +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + // ReadResponse reads and returns an HTTP response from r. The // req parameter specifies the Request that corresponds to // this Response. Clients must call resp.Body.Close when finished @@ -93,7 +95,7 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err os.Error) } return nil, err } - f := strings.Split(line, " ", 3) + f := strings.SplitN(line, " ", 3) if len(f) < 2 { return nil, &badStringError{"malformed HTTP response", line} } @@ -127,8 +129,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err os.Error) return nil, err } - resp.SetCookie = readSetCookies(resp.Header) - return resp, nil } @@ -200,10 +200,6 @@ func (resp *Response) Write(w io.Writer) os.Error { return err } - if err = writeSetCookies(w, resp.SetCookie); err != nil { - return err - } - // End-of-header io.WriteString(w, "\r\n") diff --git a/src/pkg/http/reverseproxy.go b/src/pkg/http/reverseproxy.go index 9a9e21599..e4ce1e34c 100644 --- a/src/pkg/http/reverseproxy.go +++ b/src/pkg/http/reverseproxy.go @@ -92,10 +92,6 @@ func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) { } } - for _, cookie := range res.SetCookie { - SetCookie(rw, cookie) - } - rw.WriteHeader(res.StatusCode) if res.Body != nil { diff --git a/src/pkg/http/reverseproxy_test.go b/src/pkg/http/reverseproxy_test.go index d7bcde90d..b2dd24633 100644 --- a/src/pkg/http/reverseproxy_test.go +++ b/src/pkg/http/reverseproxy_test.go @@ -17,6 +17,9 @@ func TestReverseProxy(t *testing.T) { const backendResponse = "I am the backend" const backendStatus = 404 backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } if r.Header.Get("X-Forwarded-For") == "" { t.Errorf("didn't get X-Forwarded-For header") } @@ -49,10 +52,10 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Header.Get("X-Foo"), "bar"; g != e { t.Errorf("got X-Foo %q; expected %q", g, e) } - if g, e := len(res.SetCookie), 1; g != e { + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { t.Fatalf("got %d SetCookies, want %d", g, e) } - if cookie := res.SetCookie[0]; cookie.Name != "flavor" { + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { t.Errorf("unexpected cookie %q", cookie.Name) } bodyBytes, _ := ioutil.ReadAll(res.Body) diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index dc4594a79..55a9cbf70 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -373,11 +373,8 @@ func TestIdentityResponse(t *testing.T) { } } -// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. -func TestServeHTTP10Close(t *testing.T) { - s := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - ServeFile(w, r, "testdata/file") - })) +func testTcpConnectionCloses(t *testing.T, req string, h Handler) { + s := httptest.NewServer(h) defer s.Close() conn, err := net.Dial("tcp", s.Listener.Addr().String()) @@ -386,7 +383,7 @@ func TestServeHTTP10Close(t *testing.T) { } defer conn.Close() - _, err = fmt.Fprint(conn, "GET / HTTP/1.0\r\n\r\n") + _, err = fmt.Fprint(conn, req) if err != nil { t.Fatal("print error:", err) } @@ -414,6 +411,27 @@ func TestServeHTTP10Close(t *testing.T) { success <- true } +// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. +func TestServeHTTP10Close(t *testing.T) { + testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) +} + +// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, +// even for HTTP/1.1 requests. +func TestHandlersCanSetConnectionClose11(t *testing.T) { + testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + })) +} + +func TestHandlersCanSetConnectionClose10(t *testing.T) { + testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + })) +} + func TestSetsRemoteAddr(t *testing.T) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) @@ -522,7 +540,12 @@ func TestHeadResponses(t *testing.T) { func TestTLSServer(t *testing.T) { ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "tls=%v", r.TLS != nil) + if r.TLS != nil { + w.Header().Set("X-TLS-Set", "true") + if r.TLS.HandshakeComplete { + w.Header().Set("X-TLS-HandshakeComplete", "true") + } + } })) defer ts.Close() if !strings.HasPrefix(ts.URL, "https://") { @@ -530,20 +553,17 @@ func TestTLSServer(t *testing.T) { } res, err := Get(ts.URL) if err != nil { - t.Error(err) + t.Fatal(err) } if res == nil { t.Fatalf("got nil Response") } - if res.Body == nil { - t.Fatalf("got nil Response.Body") - } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Error(err) + defer res.Body.Close() + if res.Header.Get("X-TLS-Set") != "true" { + t.Errorf("expected X-TLS-Set response header") } - if e, g := "tls=true", string(body); e != g { - t.Errorf("expected body %q; got %q", e, g) + if res.Header.Get("X-TLS-HandshakeComplete") != "true" { + t.Errorf("expected X-TLS-HandshakeComplete header") } } @@ -781,6 +801,45 @@ func TestHandlerPanic(t *testing.T) { } } +func TestNoDate(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header()["Date"] = nil + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + _, present := res.Header["Date"] + if present { + t.Fatalf("Expected no Date header; got %v", res.Header["Date"]) + } +} + +func TestStripPrefix(t *testing.T) { + h := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Path", r.URL.Path) + }) + ts := httptest.NewServer(StripPrefix("/foo", h)) + defer ts.Close() + + res, err := Get(ts.URL + "/foo/bar") + if err != nil { + t.Fatal(err) + } + if g, e := res.Header.Get("X-Path"), "/bar"; g != e { + t.Errorf("test 1: got %s, want %s", g, e) + } + + res, err = Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + if g, e := res.StatusCode, 404; g != e { + t.Errorf("test 2: got status %v, want %v", g, e) + } +} + type errorListener struct { errs []os.Error } diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index d4638f127..ab960f4f0 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -20,7 +20,7 @@ import ( "net" "os" "path" - "runtime" + "runtime/debug" "strconv" "strings" "sync" @@ -152,6 +152,7 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { c.buf = bufio.NewReadWriter(br, bw) if tlsConn, ok := rwc.(*tls.Conn); ok { + tlsConn.Handshake() c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() } @@ -254,7 +255,7 @@ func (w *response) WriteHeader(code int) { } } - if w.header.Get("Date") == "" { + if _, ok := w.header["Date"]; !ok { w.Header().Set("Date", time.UTC().Format(TimeFormat)) } @@ -314,6 +315,10 @@ func (w *response) WriteHeader(code int) { w.closeAfterReply = true } + if w.header.Get("Connection") == "close" { + w.closeAfterReply = true + } + // Cannot use Content-Length with non-identity Transfer-Encoding. if w.chunking { w.header.Del("Content-Length") @@ -405,7 +410,7 @@ func errorKludge(w *response) { // Is it a broken browser? var msg string - switch agent := w.req.UserAgent; { + switch agent := w.req.UserAgent(); { case strings.Contains(agent, "MSIE"): msg = "Internet Explorer" case strings.Contains(agent, "Chrome/"): @@ -416,7 +421,7 @@ func errorKludge(w *response) { msg += " would ignore this error page if this text weren't here.\n" // Is it text? ("Content-Type" is always in the map) - baseType := strings.Split(w.header.Get("Content-Type"), ";", 2)[0] + baseType := strings.SplitN(w.header.Get("Content-Type"), ";", 2)[0] switch baseType { case "text/html": io.WriteString(w, "<!-- ") @@ -490,23 +495,9 @@ func (c *conn) serve() { } c.rwc.Close() - // TODO(rsc,bradfitz): this is boilerplate. move it to runtime.Stack() var buf bytes.Buffer fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err) - for i := 1; i < 20; i++ { - pc, file, line, ok := runtime.Caller(i) - if !ok { - break - } - var name string - f := runtime.FuncForPC(pc) - if f != nil { - name = f.Name() - } else { - name = fmt.Sprintf("%#x", pc) - } - fmt.Fprintf(&buf, " %s %s:%d\n", name, file, line) - } + buf.Write(debug.Stack()) log.Print(buf.String()) }() @@ -584,7 +575,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err os.Error) // Handler object that calls f. type HandlerFunc func(ResponseWriter, *Request) -// ServeHTTP calls f(w, req). +// ServeHTTP calls f(w, r). func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { f(w, r) } @@ -605,6 +596,22 @@ func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", Sta // that replies to each request with a ``404 page not found'' reply. func NotFoundHandler() Handler { return HandlerFunc(NotFound) } +// StripPrefix returns a handler that serves HTTP requests +// by removing the given prefix from the request URL's Path +// and invoking the handler h. StripPrefix handles a +// request for a path that doesn't begin with prefix by +// replying with an HTTP 404 not found error. +func StripPrefix(prefix string, h Handler) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + if !strings.HasPrefix(r.URL.Path, prefix) { + NotFound(w, r) + return + } + r.URL.Path = r.URL.Path[len(prefix):] + h.ServeHTTP(w, r) + }) +} + // Redirect replies to the request with a redirect to url, // which may be a path relative to the request path. func Redirect(w ResponseWriter, r *Request, url string, code int) { @@ -922,7 +929,9 @@ func ListenAndServe(addr string, handler Handler) os.Error { // ListenAndServeTLS acts identically to ListenAndServe, except that it // expects HTTPS connections. Additionally, files containing a certificate and -// matching private key for the server must be provided. +// matching private key for the server must be provided. If the certificate +// is signed by a certificate authority, the certFile should be the concatenation +// of the server's certificate followed by the CA's certificate. // // A trivial example server is: // @@ -947,6 +956,24 @@ func ListenAndServe(addr string, handler Handler) os.Error { // // One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem. func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) os.Error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and +// then calls Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for +// the server must be provided. If the certificate is signed by a +// certificate authority, the certFile should be the concatenation +// of the server's certificate followed by the CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (s *Server) ListenAndServeTLS(certFile, keyFile string) os.Error { + addr := s.Addr + if addr == "" { + addr = ":https" + } config := &tls.Config{ Rand: rand.Reader, Time: time.Seconds, @@ -966,7 +993,7 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han } tlsListener := tls.NewListener(conn, config) - return Serve(tlsListener, handler) + return s.Serve(tlsListener) } // TimeoutHandler returns a Handler that runs h with the given time limit. diff --git a/src/pkg/http/spdy/read.go b/src/pkg/http/spdy/read.go index 159dbc578..c6b6ab3af 100644 --- a/src/pkg/http/spdy/read.go +++ b/src/pkg/http/spdy/read.go @@ -80,7 +80,7 @@ func (frame *HeadersFrame) read(h ControlFrameHeader, f *Framer) os.Error { func newControlFrame(frameType ControlFrameType) (controlFrame, os.Error) { ctor, ok := cframeCtor[frameType] if !ok { - return nil, InvalidControlFrame + return nil, &Error{Err: InvalidControlFrame} } return ctor(), nil } @@ -97,30 +97,12 @@ var cframeCtor = map[ControlFrameType]func() controlFrame{ // TODO(willchan): Add TypeWindowUpdate } -type corkedReader struct { - r io.Reader - ch chan int - n int -} - -func (cr *corkedReader) Read(p []byte) (int, os.Error) { - if cr.n == 0 { - cr.n = <-cr.ch - } - if len(p) > cr.n { - p = p[:cr.n] - } - n, err := cr.r.Read(p) - cr.n -= n - return n, err -} - -func (f *Framer) uncorkHeaderDecompressor(payloadSize int) os.Error { +func (f *Framer) uncorkHeaderDecompressor(payloadSize int64) os.Error { if f.headerDecompressor != nil { - f.headerReader.ch <- payloadSize + f.headerReader.N = payloadSize return nil } - f.headerReader = corkedReader{r: f.r, ch: make(chan int, 1), n: payloadSize} + f.headerReader = io.LimitedReader{R: f.r, N: payloadSize} decompressor, err := zlib.NewReaderDict(&f.headerReader, []byte(HeaderDictionary)) if err != nil { return err @@ -161,11 +143,12 @@ func (f *Framer) parseControlFrame(version uint16, frameType ControlFrameType) ( return cframe, nil } -func parseHeaderValueBlock(r io.Reader) (http.Header, os.Error) { +func parseHeaderValueBlock(r io.Reader, streamId uint32) (http.Header, os.Error) { var numHeaders uint16 if err := binary.Read(r, binary.BigEndian, &numHeaders); err != nil { return nil, err } + var e os.Error h := make(http.Header, int(numHeaders)) for i := 0; i < int(numHeaders); i++ { var length uint16 @@ -178,10 +161,11 @@ func parseHeaderValueBlock(r io.Reader) (http.Header, os.Error) { } name := string(nameBytes) if name != strings.ToLower(name) { - return nil, UnlowercasedHeaderName + e = &Error{UnlowercasedHeaderName, streamId} + name = strings.ToLower(name) } if h[name] != nil { - return nil, DuplicateHeaders + e = &Error{DuplicateHeaders, streamId} } if err := binary.Read(r, binary.BigEndian, &length); err != nil { return nil, err @@ -190,11 +174,14 @@ func parseHeaderValueBlock(r io.Reader) (http.Header, os.Error) { if _, err := io.ReadFull(r, value); err != nil { return nil, err } - valueList := strings.Split(string(value), "\x00", -1) + valueList := strings.Split(string(value), "\x00") for _, v := range valueList { h.Add(name, v) } } + if e != nil { + return h, e + } return h, nil } @@ -214,14 +201,25 @@ func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame) reader := f.r if !f.headerCompressionDisabled { - f.uncorkHeaderDecompressor(int(h.length - 10)) + f.uncorkHeaderDecompressor(int64(h.length - 10)) reader = f.headerDecompressor } - frame.Headers, err = parseHeaderValueBlock(reader) + frame.Headers, err = parseHeaderValueBlock(reader, frame.StreamId) + if !f.headerCompressionDisabled && ((err == os.EOF && f.headerReader.N == 0) || f.headerReader.N != 0) { + err = &Error{WrongCompressedPayloadSize, 0} + } if err != nil { return err } + // Remove this condition when we bump Version to 3. + if Version >= 3 { + for h, _ := range frame.Headers { + if invalidReqHeaders[h] { + return &Error{InvalidHeaderPresent, frame.StreamId} + } + } + } return nil } @@ -237,13 +235,24 @@ func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) o } reader := f.r if !f.headerCompressionDisabled { - f.uncorkHeaderDecompressor(int(h.length - 6)) + f.uncorkHeaderDecompressor(int64(h.length - 6)) reader = f.headerDecompressor } - frame.Headers, err = parseHeaderValueBlock(reader) + frame.Headers, err = parseHeaderValueBlock(reader, frame.StreamId) + if !f.headerCompressionDisabled && ((err == os.EOF && f.headerReader.N == 0) || f.headerReader.N != 0) { + err = &Error{WrongCompressedPayloadSize, 0} + } if err != nil { return err } + // Remove this condition when we bump Version to 3. + if Version >= 3 { + for h, _ := range frame.Headers { + if invalidRespHeaders[h] { + return &Error{InvalidHeaderPresent, frame.StreamId} + } + } + } return nil } @@ -259,13 +268,31 @@ func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) os. } reader := f.r if !f.headerCompressionDisabled { - f.uncorkHeaderDecompressor(int(h.length - 6)) + f.uncorkHeaderDecompressor(int64(h.length - 6)) reader = f.headerDecompressor } - frame.Headers, err = parseHeaderValueBlock(reader) + frame.Headers, err = parseHeaderValueBlock(reader, frame.StreamId) + if !f.headerCompressionDisabled && ((err == os.EOF && f.headerReader.N == 0) || f.headerReader.N != 0) { + err = &Error{WrongCompressedPayloadSize, 0} + } if err != nil { return err } + + // Remove this condition when we bump Version to 3. + if Version >= 3 { + var invalidHeaders map[string]bool + if frame.StreamId%2 == 0 { + invalidHeaders = invalidReqHeaders + } else { + invalidHeaders = invalidRespHeaders + } + for h, _ := range frame.Headers { + if invalidHeaders[h] { + return &Error{InvalidHeaderPresent, frame.StreamId} + } + } + } return nil } @@ -279,7 +306,6 @@ func (f *Framer) parseDataFrame(streamId uint32) (*DataFrame, os.Error) { frame.Flags = DataFlags(length >> 24) length &= 0xffffff frame.Data = make([]byte, length) - // TODO(willchan): Support compressed data frames. if _, err := io.ReadFull(f.r, frame.Data); err != nil { return nil, err } diff --git a/src/pkg/http/spdy/spdy_test.go b/src/pkg/http/spdy/spdy_test.go index 9100e1ea8..cb91e0286 100644 --- a/src/pkg/http/spdy/spdy_test.go +++ b/src/pkg/http/spdy/spdy_test.go @@ -21,7 +21,8 @@ func TestHeaderParsing(t *testing.T) { var headerValueBlockBuf bytes.Buffer writeHeaderValueBlock(&headerValueBlockBuf, headers) - newHeaders, err := parseHeaderValueBlock(&headerValueBlockBuf) + const bogusStreamId = 1 + newHeaders, err := parseHeaderValueBlock(&headerValueBlockBuf, bogusStreamId) if err != nil { t.Fatal("parseHeaderValueBlock:", err) } diff --git a/src/pkg/http/spdy/types.go b/src/pkg/http/spdy/types.go index 5a665f04f..41cafb174 100644 --- a/src/pkg/http/spdy/types.go +++ b/src/pkg/http/spdy/types.go @@ -10,7 +10,6 @@ import ( "http" "io" "os" - "strconv" ) // Data Frame Format @@ -302,33 +301,41 @@ const HeaderDictionary = "optionsgetheadpostputdeletetrace" + "chunkedtext/htmlimage/pngimage/jpgimage/gifapplication/xmlapplication/xhtmltext/plainpublicmax-age" + "charset=iso-8859-1utf-8gzipdeflateHTTP/1.1statusversionurl\x00" -type FramerError int +// A SPDY specific error. +type ErrorCode string const ( - Internal FramerError = iota - InvalidControlFrame - UnlowercasedHeaderName - DuplicateHeaders - UnknownFrameType - InvalidDataFrame + UnlowercasedHeaderName ErrorCode = "header was not lowercased" + DuplicateHeaders ErrorCode = "multiple headers with same name" + WrongCompressedPayloadSize ErrorCode = "compressed payload size was incorrect" + UnknownFrameType ErrorCode = "unknown frame type" + InvalidControlFrame ErrorCode = "invalid control frame" + InvalidDataFrame ErrorCode = "invalid data frame" + InvalidHeaderPresent ErrorCode = "frame contained invalid header" ) -func (e FramerError) String() string { - switch e { - case Internal: - return "Internal" - case InvalidControlFrame: - return "InvalidControlFrame" - case UnlowercasedHeaderName: - return "UnlowercasedHeaderName" - case DuplicateHeaders: - return "DuplicateHeaders" - case UnknownFrameType: - return "UnknownFrameType" - case InvalidDataFrame: - return "InvalidDataFrame" - } - return "Error(" + strconv.Itoa(int(e)) + ")" +// Error contains both the type of error and additional values. StreamId is 0 +// if Error is not associated with a stream. +type Error struct { + Err ErrorCode + StreamId uint32 +} + +func (e *Error) String() string { + return string(e.Err) +} + +var invalidReqHeaders = map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Proxy-Connection": true, + "Transfer-Encoding": true, +} + +var invalidRespHeaders = map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Transfer-Encoding": true, } // Framer handles serializing/deserializing SPDY frames, including compressing/ @@ -339,7 +346,7 @@ type Framer struct { headerBuf *bytes.Buffer headerCompressor *zlib.Writer r io.Reader - headerReader corkedReader + headerReader io.LimitedReader headerDecompressor io.ReadCloser } diff --git a/src/pkg/http/spdy/write.go b/src/pkg/http/spdy/write.go index aa1679f1b..7d40bbe9f 100644 --- a/src/pkg/http/spdy/write.go +++ b/src/pkg/http/spdy/write.go @@ -267,10 +267,9 @@ func (f *Framer) writeHeadersFrame(frame *HeadersFrame) (err os.Error) { func (f *Framer) writeDataFrame(frame *DataFrame) (err os.Error) { // Validate DataFrame if frame.StreamId&0x80000000 != 0 || len(frame.Data) >= 0x0f000000 { - return InvalidDataFrame + return &Error{InvalidDataFrame, frame.StreamId} } - // TODO(willchan): Support data compression. // Serialize frame to Writer if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { return diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go index b54508e7a..b65d99a6f 100644 --- a/src/pkg/http/transfer.go +++ b/src/pkg/http/transfer.go @@ -5,6 +5,7 @@ package http import ( + "bytes" "bufio" "io" "io/ioutil" @@ -17,7 +18,8 @@ import ( // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. type transferWriter struct { - Body io.ReadCloser + Body io.Reader + BodyCloser io.Closer ResponseToHEAD bool ContentLength int64 Close bool @@ -33,16 +35,37 @@ func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { switch rr := r.(type) { case *Request: t.Body = rr.Body + t.BodyCloser = rr.Body t.ContentLength = rr.ContentLength t.Close = rr.Close t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - if t.Body != nil && t.ContentLength <= 0 && len(t.TransferEncoding) == 0 && atLeastHTTP11 { - t.TransferEncoding = []string{"chunked"} + if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 { + if t.ContentLength == 0 { + // Test to see if it's actually zero or just unset. + var buf [1]byte + n, _ := io.ReadFull(t.Body, buf[:]) + if n == 1 { + // Oh, guess there is data in this Body Reader after all. + // The ContentLength field just wasn't set. + // Stich the Body back together again, re-attaching our + // consumed byte. + t.ContentLength = -1 + t.Body = io.MultiReader(bytes.NewBuffer(buf[:]), t.Body) + } else { + // Body is actually empty. + t.Body = nil + t.BodyCloser = nil + } + } + if t.ContentLength < 0 { + t.TransferEncoding = []string{"chunked"} + } } case *Response: t.Body = rr.Body + t.BodyCloser = rr.Body t.ContentLength = rr.ContentLength t.Close = rr.Close t.TransferEncoding = rr.TransferEncoding @@ -147,7 +170,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) { if err != nil { return err } - if err = t.Body.Close(); err != nil { + if err = t.BodyCloser.Close(); err != nil { return err } } @@ -195,6 +218,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { t := &transferReader{} // Unify input + isResponse := false switch rr := msg.(type) { case *Response: t.Header = rr.Header @@ -203,6 +227,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { t.ProtoMajor = rr.ProtoMajor t.ProtoMinor = rr.ProtoMinor t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header) + isResponse = true case *Request: t.Header = rr.Header t.ProtoMajor = rr.ProtoMajor @@ -211,6 +236,8 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { // Responses with status code 200, responding to a GET method t.StatusCode = 200 t.RequestMethod = "GET" + default: + panic("unexpected type") } // Default to HTTP/1.1 @@ -224,7 +251,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { return err } - t.ContentLength, err = fixLength(t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) + t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) if err != nil { return err } @@ -252,7 +279,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + t.Body = &body{Reader: NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} case t.ContentLength >= 0: // TODO: limit the Content-Length. This is an easy DoS vector. t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close} @@ -265,9 +292,6 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { // Persistent connection (i.e. HTTP/1.1) t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} } - // TODO(petar): It may be a good idea, for extra robustness, to - // assume ContentLength=0 for GET requests (and other special - // cases?). This logic should be in fixLength(). } // Unify output @@ -310,7 +334,7 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, os.Erro return nil, nil } - encodings := strings.Split(raw[0], ",", -1) + encodings := strings.Split(raw[0], ",") te := make([]string, 0, len(encodings)) // TODO: Even though we only support "identity" and "chunked" // encodings, the loop below is designed with foresight. One @@ -345,7 +369,7 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, os.Erro // Determine the expected body length, using RFC 2616 Section 4.4. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. -func fixLength(status int, requestMethod string, header Header, te []string) (int64, os.Error) { +func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, os.Error) { // Logic based on response type or status if noBodyExpected(requestMethod) { @@ -376,6 +400,14 @@ func fixLength(status int, requestMethod string, header Header, te []string) (in header.Del("Content-Length") } + if !isResponse && requestMethod == "GET" { + // RFC 2616 doesn't explicitly permit nor forbid an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + return 0, nil + } + // Logic based on media type. The purpose of the following code is just // to detect whether the unsupported "multipart/byteranges" is being // used. A proper Content-Type parser is needed in the future. @@ -418,7 +450,7 @@ func fixTrailer(header Header, te []string) (Header, os.Error) { header.Del("Trailer") trailer := make(Header) - keys := strings.Split(raw, ",", -1) + keys := strings.Split(raw, ",") for _, key := range keys { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index c907d85fd..3c16c880d 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -76,12 +76,12 @@ func ProxyFromEnvironment(req *Request) (*URL, os.Error) { } proxyURL, err := ParseRequestURL(proxy) if err != nil { - return nil, os.ErrorString("invalid proxy address") + return nil, os.NewError("invalid proxy address") } if proxyURL.Host == "" { proxyURL, err = ParseRequestURL("http://" + proxy) if err != nil { - return nil, os.ErrorString("invalid proxy address") + return nil, os.NewError("invalid proxy address") } } return proxyURL, nil @@ -329,9 +329,9 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { return nil, err } if resp.StatusCode != 200 { - f := strings.Split(resp.Status, " ", 2) + f := strings.SplitN(resp.Status, " ", 2) conn.Close() - return nil, os.ErrorString(f[1]) + return nil, os.NewError(f[1]) } } @@ -383,7 +383,7 @@ func useProxy(addr string) bool { addr = addr[:strings.LastIndex(addr, ":")] } - for _, p := range strings.Split(no_proxy, ",", -1) { + for _, p := range strings.Split(no_proxy, ",") { p = strings.ToLower(strings.TrimSpace(p)) if len(p) == 0 { continue diff --git a/src/pkg/http/url.go b/src/pkg/http/url.go index 05b1662d3..e934b27c4 100644 --- a/src/pkg/http/url.go +++ b/src/pkg/http/url.go @@ -299,7 +299,7 @@ func getscheme(rawurl string) (scheme, path string, err os.Error) { } case c == ':': if i == 0 { - return "", "", os.ErrorString("missing protocol scheme") + return "", "", os.NewError("missing protocol scheme") } return rawurl[0:i], rawurl[i+1:], nil default: @@ -348,8 +348,13 @@ func ParseRequestURL(rawurl string) (url *URL, err os.Error) { // in which case only absolute URLs or path-absolute relative URLs are allowed. // If viaRequest is false, all forms of relative URLs are allowed. func parseURL(rawurl string, viaRequest bool) (url *URL, err os.Error) { + var ( + leadingSlash bool + path string + ) + if rawurl == "" { - err = os.ErrorString("empty url") + err = os.NewError("empty url") goto Error } url = new(URL) @@ -357,12 +362,10 @@ func parseURL(rawurl string, viaRequest bool) (url *URL, err os.Error) { // Split off possible leading "http:", "mailto:", etc. // Cannot contain escaped characters. - var path string if url.Scheme, path, err = getscheme(rawurl); err != nil { goto Error } - - leadingSlash := strings.HasPrefix(path, "/") + leadingSlash = strings.HasPrefix(path, "/") if url.Scheme != "" && !leadingSlash { // RFC 2396: @@ -377,7 +380,7 @@ func parseURL(rawurl string, viaRequest bool) (url *URL, err os.Error) { url.OpaquePath = true } else { if viaRequest && !leadingSlash { - err = os.ErrorString("invalid URI for request") + err = os.NewError("invalid URI for request") goto Error } @@ -411,7 +414,7 @@ func parseURL(rawurl string, viaRequest bool) (url *URL, err os.Error) { if strings.Contains(rawHost, "%") { // Host cannot contain escaped characters. - err = os.ErrorString("hexadecimal escape in host") + err = os.NewError("hexadecimal escape in host") goto Error } url.Host = rawHost @@ -505,8 +508,8 @@ func (v Values) Encode() string { // resolvePath applies special path segments from refs and applies // them to base, per RFC 2396. func resolvePath(basepath string, refpath string) string { - base := strings.Split(basepath, "/", -1) - refs := strings.Split(refpath, "/", -1) + base := strings.Split(basepath, "/") + refs := strings.Split(refpath, "/") if len(base) == 0 { base = []string{""} } |