diff options
Diffstat (limited to 'src/pkg/http/serve_test.go')
-rw-r--r-- | src/pkg/http/serve_test.go | 273 |
1 files changed, 262 insertions, 11 deletions
diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index 86d64bdbb..0142dead9 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -15,6 +15,7 @@ import ( "io/ioutil" "os" "net" + "reflect" "strings" "testing" "time" @@ -144,7 +145,7 @@ func TestConsumingBodyOnNextConn(t *testing.T) { type stringHandler string func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) { - w.SetHeader("Result", string(s)) + w.Header().Set("Result", string(s)) } var handlers = []struct { @@ -174,7 +175,7 @@ func TestHostHandlers(t *testing.T) { ts := httptest.NewServer(nil) defer ts.Close() - conn, err := net.Dial("tcp", "", ts.Listener.Addr().String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } @@ -216,7 +217,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { mux.ServeHTTP(resp, req) - if loc, expected := resp.Header.Get("Location"), "/foo.txt"; loc != expected { + if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected { t.Errorf("Expected Location header set to %q; got %q", expected, loc) return } @@ -229,7 +230,8 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0}) + // TODO(bradfitz): convert this to use httptest.Server + l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen error: %v", err) } @@ -248,7 +250,9 @@ func TestServerTimeouts(t *testing.T) { url := fmt.Sprintf("http://localhost:%d/", addr.Port) // Hit the HTTP server successfully. - r, _, err := Get(url) + tr := &Transport{DisableKeepAlives: true} // they interfere with this test + c := &Client{Transport: tr} + r, _, err := c.Get(url) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -261,7 +265,7 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Nanoseconds() - conn, err := net.Dial("tcp", "", fmt.Sprintf("localhost:%d", addr.Port)) + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port)) if err != nil { t.Fatalf("Dial: %v", err) } @@ -294,8 +298,8 @@ func TestServerTimeouts(t *testing.T) { // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { - rw.SetHeader("Content-Length", "3") - rw.SetHeader("Transfer-Encoding", req.FormValue("te")) + rw.Header().Set("Content-Length", "3") + rw.Header().Set("Transfer-Encoding", req.FormValue("te")) switch { case req.FormValue("overwrite") == "1": _, err := rw.Write([]byte("foo TOO LONG")) @@ -303,7 +307,7 @@ func TestIdentityResponse(t *testing.T) { t.Errorf("expected ErrContentLength; got %v", err) } case req.FormValue("underwrite") == "1": - rw.SetHeader("Content-Length", "500") + rw.Header().Set("Content-Length", "500") rw.Write([]byte("too short")) default: rw.Write([]byte("foo")) @@ -333,6 +337,7 @@ func TestIdentityResponse(t *testing.T) { t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)", url, expected, tl, res.TransferEncoding) } + res.Body.Close() } // Verify that ErrContentLength is returned @@ -341,10 +346,9 @@ func TestIdentityResponse(t *testing.T) { if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } - // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. - conn, err := net.Dial("tcp", "", ts.Listener.Addr().String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) } @@ -365,3 +369,250 @@ func TestIdentityResponse(t *testing.T) { expectedSuffix, string(got)) } } + +// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. +func TestServeHTTP10Close(t *testing.T) { + s := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) + defer s.Close() + + conn, err := net.Dial("tcp", s.Listener.Addr().String()) + if err != nil { + t.Fatal("dial error:", err) + } + defer conn.Close() + + _, err = fmt.Fprint(conn, "GET / HTTP/1.0\r\n\r\n") + if err != nil { + t.Fatal("print error:", err) + } + + r := bufio.NewReader(conn) + _, err = ReadResponse(r, "GET") + if err != nil { + t.Fatal("ReadResponse error:", err) + } + + success := make(chan bool) + go func() { + select { + case <-time.After(5e9): + t.Fatal("body not closed after 5s") + case <-success: + } + }() + + _, err = ioutil.ReadAll(r) + if err != nil { + t.Fatal("read error:", err) + } + + success <- true +} + +func TestSetsRemoteAddr(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%s", r.RemoteAddr) + })) + defer ts.Close() + + res, _, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + ip := string(body) + if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") { + t.Fatalf("Expected local addr; got %q", ip) + } +} + +func TestChunkedResponseHeaders(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + fmt.Fprintf(w, "I am a chunked response.") + })) + defer ts.Close() + + res, _, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if g, e := res.ContentLength, int64(-1); g != e { + t.Errorf("expected ContentLength of %d; got %d", e, g) + } + if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) { + t.Errorf("expected TransferEncoding of %v; got %v", e, g) + } + if _, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length") + } +} + +// Test304Responses verifies that 304s don't declare that they're +// chunking in their response headers and aren't allowed to produce +// output. +func Test304Responses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNotModified) + _, err := w.Write([]byte("illegal body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +// TestHeadResponses verifies that responses to HEAD requests don't +// declare that they're chunking in their response headers and aren't +// allowed to produce output. +func TestHeadResponses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("Ignored body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, err := Head(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +func TestTLSServer(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "tls=%v", r.TLS != nil) + })) + defer ts.Close() + if !strings.HasPrefix(ts.URL, "https://") { + t.Fatalf("expected test TLS server to start with https://, got %q", ts.URL) + } + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if res == nil { + t.Fatalf("got nil Response") + } + if res.Body == nil { + t.Fatalf("got nil Response.Body") + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if e, g := "tls=true", string(body); e != g { + t.Errorf("expected body %q; got %q", e, g) + } +} + +type serverExpectTest struct { + contentLength int // of request body + expectation string // e.g. "100-continue" + readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) + expectedResponse string // expected substring in first line of http response +} + +var serverExpectTests = []serverExpectTest{ + // Normal 100-continues, case-insensitive. + {100, "100-continue", true, "100 Continue"}, + {100, "100-cOntInUE", true, "100 Continue"}, + + // No 100-continue. + {100, "", true, "200 OK"}, + + // 100-continue but requesting client to deny us, + // so it never eads the body. + {100, "100-continue", false, "401 Unauthorized"}, + // Likewise without 100-continue: + {100, "", false, "401 Unauthorized"}, + + // Non-standard expectations are failures + {0, "a-pony", false, "417 Expectation Failed"}, + + // Expect-100 requested but no body + {0, "100-continue", true, "400 Bad Request"}, +} + +// Tests that the server responds to the "Expect" request header +// correctly. +func TestServerExpect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Note using r.FormValue("readbody") because for POST + // requests that would read from r.Body, which we only + // conditionally want to do. + if strings.Contains(r.URL.RawPath, "readbody=true") { + ioutil.ReadAll(r.Body) + w.Write([]byte("Hi")) + } else { + w.WriteHeader(StatusUnauthorized) + } + })) + defer ts.Close() + + runTest := func(test serverExpectTest) { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + sendf := func(format string, args ...interface{}) { + _, err := fmt.Fprintf(conn, format, args...) + if err != nil { + t.Fatalf("Error writing %q: %v", format, err) + } + } + go func() { + sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + "Connection: close\r\n"+ + "Content-Length: %d\r\n"+ + "Expect: %s\r\nHost: foo\r\n\r\n", + test.readBody, test.contentLength, test.expectation) + if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + body := strings.Repeat("A", test.contentLength) + sendf(body) + } + }() + bufr := bufio.NewReader(conn) + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatalf("ReadString: %v", err) + } + if !strings.Contains(line, test.expectedResponse) { + t.Errorf("for test %#v got first line=%q", test, line) + } + } + + for _, test := range serverExpectTests { + runTest(test) + } +} |