diff options
Diffstat (limited to 'src/pkg/net/http/client_test.go')
-rw-r--r-- | src/pkg/net/http/client_test.go | 250 |
1 files changed, 243 insertions, 7 deletions
diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index 9b4261b9f..88649bb16 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -7,7 +7,9 @@ package http_test import ( + "bytes" "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -53,6 +55,7 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } func TestClient(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -70,6 +73,7 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -92,6 +96,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} url := "http://dummy.faketld/" @@ -108,6 +113,7 @@ func TestGetRequestFormat(t *testing.T) { } func TestPostRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -134,6 +140,7 @@ func TestPostRequestFormat(t *testing.T) { } func TestPostFormRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -175,6 +182,7 @@ func TestPostFormRequestFormat(t *testing.T) { } func TestRedirects(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) @@ -218,6 +226,10 @@ func TestRedirects(t *testing.T) { return checkErr }} res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + res.Body.Close() finalUrl := res.Request.URL.String() if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { t.Errorf("with custom client, expected error %q, got %q", e, g) @@ -231,9 +243,63 @@ func TestRedirects(t *testing.T) { checkErr = errors.New("no redirects allowed") res, err = c.Get(ts.URL) - finalUrl = res.Request.URL.String() - if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { - t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { + t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) + } + if res == nil { + t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") + } + res.Body.Close() + if res.Header.Get("Location") == "" { + t.Errorf("no Location header in Response") + } +} + +func TestPostRedirects(t *testing.T) { + defer checkLeakedTransports(t) + var log struct { + sync.Mutex + bytes.Buffer + } + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log.Lock() + fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + log.Unlock() + if v := r.URL.Query().Get("code"); v != "" { + code, _ := strconv.Atoi(v) + if code/100 == 3 { + w.Header().Set("Location", ts.URL) + } + w.WriteHeader(code) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int // response code + }{ + {"/", 200}, + {"/?code=301", 301}, + {"/?code=302", 200}, + {"/?code=303", 200}, + {"/?code=404", 404}, + } + for _, tt := range tests { + res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.want { + t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want) + } + } + log.Lock() + got := log.String() + log.Unlock() + want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + if got != want { + t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) } } @@ -279,6 +345,10 @@ func TestClientSendsCookieFromJar(t *testing.T) { req, _ := NewRequest("GET", us, nil) client.Do(req) // Note: doesn't hit network matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ = NewRequest("POST", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) } // Just enough correctness for our redirect tests. Uses the URL.Host as the @@ -291,6 +361,9 @@ type TestJar struct { func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) { j.m.Lock() defer j.m.Unlock() + if j.perURL == nil { + j.perURL = make(map[string][]*Cookie) + } j.perURL[u.Host] = cookies } @@ -301,6 +374,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesOnRequest(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -318,14 +392,20 @@ func TestRedirectCookiesOnRequest(t *testing.T) { } func TestRedirectCookiesJar(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() - c := &Client{} - c.Jar = &TestJar{perURL: make(map[string][]*Cookie)} + c := &Client{ + Jar: new(TestJar), + } u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) - resp, _ := c.Get(ts.URL) + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + resp.Body.Close() matchReturnedCookies(t, expectedCookies, resp.Cookies()) } @@ -348,7 +428,72 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } } +func TestJarCalls(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + pathSuffix := r.RequestURI[1:] + if r.RequestURI == "/nosetcookie" { + return // dont set cookies for this path + } + SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix}) + if r.RequestURI == "/" { + Redirect(w, r, "http://secondhost.fake/secondpath", 302) + } + })) + defer ts.Close() + jar := new(RecordingJar) + c := &Client{ + Jar: jar, + Transport: &Transport{ + Dial: func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) + }, + }, + } + _, err := c.Get("http://firsthost.fake/") + if err != nil { + t.Fatal(err) + } + _, err = c.Get("http://firsthost.fake/nosetcookie") + if err != nil { + t.Fatal(err) + } + got := jar.log.String() + want := `Cookies("http://firsthost.fake/") +SetCookie("http://firsthost.fake/", [name=val]) +Cookies("http://secondhost.fake/secondpath") +SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath]) +Cookies("http://firsthost.fake/nosetcookie") +` + if got != want { + t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want) + } +} + +// RecordingJar keeps a log of calls made to it, without +// tracking any cookies. +type RecordingJar struct { + mu sync.Mutex + log bytes.Buffer +} + +func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.logf("SetCookie(%q, %v)\n", u, cookies) +} + +func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { + j.logf("Cookies(%q)\n", u) + return nil +} + +func (j *RecordingJar) logf(format string, args ...interface{}) { + j.mu.Lock() + defer j.mu.Unlock() + fmt.Fprintf(&j.log, format, args...) +} + func TestStreamingGet(t *testing.T) { + defer checkLeakedTransports(t) say := make(chan string) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() @@ -399,6 +544,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. func TestClientWrites(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -432,6 +578,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) @@ -446,15 +593,20 @@ func TestClientInsecureTransport(t *testing.T) { InsecureSkipVerify: insecure, }, } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) } + if res != nil { + res.Body.Close() + } } } func TestClientErrorWithRequestURI(t *testing.T) { + defer checkLeakedTransports(t) req, _ := NewRequest("GET", "http://localhost:1234/", nil) req.RequestURI = "/this/field/is/illegal/and/should/error/" _, err := DefaultClient.Do(req) @@ -465,3 +617,87 @@ func TestClientErrorWithRequestURI(t *testing.T) { t.Errorf("wanted error mentioning RequestURI; got error: %v", err) } } + +func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { + certs := x509.NewCertPool() + for _, c := range ts.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return &Transport{ + TLSClientConfig: &tls.Config{RootCAs: certs}, + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != "127.0.0.1" { + t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + } + })) + defer ts.Close() + + c := &Client{Transport: newTLSTransport(t, ts)} + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = "badserver" + c := &Client{Transport: trans} + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } +} + +// Verify Response.ContentLength is populated. http://golang.org/issue/4126 +func TestClientHeadContentLength(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if v := r.FormValue("cl"); v != "" { + w.Header().Set("Content-Length", v) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int64 + }{ + {"/?cl=1234", 1234}, + {"/?cl=0", 0}, + {"", -1}, + } + for _, tt := range tests { + req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != tt.want { + t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) + } + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != 0 { + t.Errorf("Unexpected content: %q", bs) + } + } +} |