summaryrefslogtreecommitdiff
path: root/src/pkg/net/http/transport_test.go
diff options
context:
space:
mode:
authorMichael Stapelberg <stapelberg@debian.org>2013-03-04 21:27:36 +0100
committerMichael Stapelberg <michael@stapelberg.de>2013-03-04 21:27:36 +0100
commit04b08da9af0c450d645ab7389d1467308cfc2db8 (patch)
treedb247935fa4f2f94408edc3acd5d0d4f997aa0d8 /src/pkg/net/http/transport_test.go
parent917c5fb8ec48e22459d77e3849e6d388f93d3260 (diff)
downloadgolang-upstream/1.1_hg20130304.tar.gz
Imported Upstream version 1.1~hg20130304upstream/1.1_hg20130304
Diffstat (limited to 'src/pkg/net/http/transport_test.go')
-rw-r--r--src/pkg/net/http/transport_test.go577
1 files changed, 558 insertions, 19 deletions
diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go
index a9e401de5..68010e68b 100644
--- a/src/pkg/net/http/transport_test.go
+++ b/src/pkg/net/http/transport_test.go
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "net"
. "net/http"
"net/http/httptest"
"net/url"
@@ -20,6 +21,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"testing"
"time"
)
@@ -35,14 +37,78 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte(r.RemoteAddr))
})
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+ c.set.remove(c)
+ return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+ t *testing.T
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+ mutex sync.Mutex
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.closed[c] = true
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ t: t,
+ closed: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
+func (tcs *testConnSet) check(t *testing.T) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+
+ for i, c := range tcs.list {
+ if !tcs.closed[c] {
+ t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+ }
+ }
+}
+
// Two subsequent requests and verify their response is the same.
// The response from the server is our own IP:port
func TestTransportKeepAlives(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
for _, disableKeepAlive := range []bool{false, true} {
tr := &Transport{DisableKeepAlives: disableKeepAlive}
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -69,11 +135,16 @@ func TestTransportKeepAlives(t *testing.T) {
}
func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial(t)
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -92,8 +163,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
}
- body, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
}
@@ -107,15 +178,24 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
}
+
+ connSet.check(t)
}
func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial(t)
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -149,10 +229,15 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
}
+
+ connSet.check(t)
}
func TestTransportIdleCacheKeys(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -185,6 +270,7 @@ func TestTransportIdleCacheKeys(t *testing.T) {
}
func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ defer checkLeakedTransports(t)
resch := make(chan string)
gotReq := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -201,7 +287,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
c := &Client{Transport: tr}
// Start 3 outstanding requests and wait for the server to get them.
- // Their responses will hang until we we write to resch, though.
+ // Their responses will hang until we write to resch, though.
donech := make(chan bool)
doReq := func() {
resp, err := c.Get(ts.URL)
@@ -253,6 +339,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
}
func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -309,9 +396,9 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
// Test for http://golang.org/issue/2616 (appropriate issue number)
// This fails pretty reliably with GOMAXPROCS=100 or something high.
func TestStressSurpriseServerCloses(t *testing.T) {
+ defer checkLeakedTransports(t)
if testing.Short() {
- t.Logf("skipping test in short mode")
- return
+ t.Skip("skipping test in short mode")
}
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "5")
@@ -365,6 +452,7 @@ func TestStressSurpriseServerCloses(t *testing.T) {
// TestTransportHeadResponses verifies that we deal with Content-Lengths
// with no bodies properly
func TestTransportHeadResponses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
@@ -384,7 +472,7 @@ func TestTransportHeadResponses(t *testing.T) {
if e, g := "123", res.Header.Get("Content-Length"); e != g {
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
}
- if e, g := int64(0), res.ContentLength; e != g {
+ if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
}
}
@@ -393,6 +481,7 @@ func TestTransportHeadResponses(t *testing.T) {
// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
// on responses to HEAD requests.
func TestTransportHeadChunkedResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
@@ -434,6 +523,7 @@ var roundTripTests = []struct {
// Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) {
+ defer checkLeakedTransports(t)
const responseBody = "test response body"
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
accept := req.Header.Get("Accept-Encoding")
@@ -490,6 +580,7 @@ func TestRoundTripGzip(t *testing.T) {
}
func TestTransportGzip(t *testing.T) {
+ defer checkLeakedTransports(t)
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
const nRandBytes = 1024 * 1024
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -582,6 +673,7 @@ func TestTransportGzip(t *testing.T) {
}
func TestTransportProxy(t *testing.T) {
+ defer checkLeakedTransports(t)
ch := make(chan string, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "real server"
@@ -610,6 +702,7 @@ func TestTransportProxy(t *testing.T) {
// but checks that we don't recurse forever, and checks that
// Content-Encoding is removed.
func TestTransportGzipRecursive(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "gzip")
w.Write(rgz)
@@ -636,6 +729,7 @@ func TestTransportGzipRecursive(t *testing.T) {
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
+ defer checkLeakedTransports(t)
gotReqCh := make(chan bool)
unblockCh := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -698,8 +792,49 @@ func TestTransportPersistConnLeak(t *testing.T) {
}
}
+// golang.org/issue/4531: Transport leaks goroutines when
+// request.ContentLength is explicitly short
+func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ n0 := runtime.NumGoroutine()
+ body := []byte("Hello")
+ for i := 0; i < 20; i++ {
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body) - 2) // explicitly short
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("Expect an error from writing too long of a body.")
+ }
+ }
+ nhigh := runtime.NumGoroutine()
+ tr.CloseIdleConnections()
+ time.Sleep(50 * time.Millisecond)
+ runtime.GC()
+ nfinal := runtime.NumGoroutine()
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ if int(growth) > 5 {
+ t.Error("too many new goroutines")
+ }
+}
+
// This used to crash; http://golang.org/issue/3266
func TestTransportIdleConnCrash(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &Transport{}
c := &Client{Transport: tr}
@@ -724,6 +859,361 @@ func TestTransportIdleConnCrash(t *testing.T) {
<-didreq
}
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const numFoos = 5000
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
+// Test that a client receives a server's reply, even if the server doesn't read
+// the entire request body.
+func TestIssue3595(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const deniedMsg = "sorry, denied."
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ t.Errorf("Post: %v", err)
+ return
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body ReadAll: %v", err)
+ }
+ if !strings.Contains(string(got), deniedMsg) {
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+}
+
+// From http://golang.org/issue/4454 ,
+// "client fails to handle requests with no body and chunked encoding"
+func TestChunkedNoContent(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ }))
+ defer ts.Close()
+
+ for _, closeBody := range []bool{true, false} {
+ c := &Client{Transport: &Transport{}}
+ const n = 4
+ for i := 1; i <= n; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
+ } else {
+ if closeBody {
+ res.Body.Close()
+ }
+ }
+ }
+ }
+}
+
+func TestTransportConcurrency(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const maxProcs = 16
+ const numReqs = 500
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.FormValue("echo"))
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ reqs := make(chan string)
+ defer close(reqs)
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for req := range reqs {
+ res, err := c.Get(ts.URL + "/?echo=" + req)
+ if err != nil {
+ t.Errorf("error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("read error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ if string(all) != req {
+ t.Errorf("body of req %s = %q; want %q", req, all, req)
+ }
+ wg.Done()
+ res.Body.Close()
+ }
+ }()
+ }
+ for i := 0; i < numReqs; i++ {
+ reqs <- fmt.Sprintf("request-%d", i)
+ }
+ wg.Wait()
+}
+
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ _, err = io.Copy(ioutil.Discard, sres.Body)
+ if err == nil {
+ t.Errorf("Unexpected successful copy")
+ break
+ }
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+ defer r.Body.Close()
+ io.Copy(ioutil.Discard, r.Body)
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+ _, err = client.Do(req)
+ if err == nil {
+ sres.Body.Close()
+ t.Errorf("Unexpected successful PUT")
+ break
+ }
+ sres.Body.Close()
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestTransportResponseHeaderTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ time.Sleep(2 * time.Second)
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ tr := &Transport{
+ ResponseHeaderTimeout: 500 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ tests := []struct {
+ path string
+ want int
+ wantErr string
+ }{
+ {path: "/fast", want: 200},
+ {path: "/slow", wantErr: "timeout awaiting response headers"},
+ {path: "/fast", want: 200},
+ }
+ for i, tt := range tests {
+ res, err := c.Get(ts.URL + tt.path)
+ if err != nil {
+ if strings.Contains(err.Error(), tt.wantErr) {
+ continue
+ }
+ t.Errorf("%d. unexpected error: %v", i, err)
+ continue
+ }
+ if tt.wantErr != "" {
+ t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestTransportCancelRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ }))
+ defer ts.Close()
+ defer close(unblockc)
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ tr.CancelRequest(req)
+ }()
+ t0 := time.Now()
+ body, err := ioutil.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err == nil {
+ t.Error("expected an error reading the body")
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 3; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {
@@ -737,6 +1227,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) {
}
func TestTransportAltProto(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &Transport{}
c := &Client{Transport: tr}
tr.RegisterProtocol("foo", fooProto{})
@@ -754,15 +1245,58 @@ func TestTransportAltProto(t *testing.T) {
}
}
-var proxyFromEnvTests = []struct {
+func TestTransportNoHost(t *testing.T) {
+ defer checkLeakedTransports(t)
+ tr := &Transport{}
+ _, err := tr.RoundTrip(&Request{
+ Header: make(Header),
+ URL: &url.URL{
+ Scheme: "http",
+ },
+ })
+ want := "http: no Host in request URL"
+ if got := fmt.Sprint(err); got != want {
+ t.Errorf("error = %v; want %q", err, want)
+ }
+}
+
+type proxyFromEnvTest struct {
+ req string // URL to fetch; blank means "http://example.com"
env string
- wanturl string
+ noenv string
+ want string
wanterr error
-}{
- {"127.0.0.1:8080", "http://127.0.0.1:8080", nil},
- {"http://127.0.0.1:8080", "http://127.0.0.1:8080", nil},
- {"https://127.0.0.1:8080", "https://127.0.0.1:8080", nil},
- {"", "<nil>", nil},
+}
+
+func (t proxyFromEnvTest) String() string {
+ var buf bytes.Buffer
+ if t.env != "" {
+ fmt.Fprintf(&buf, "http_proxy=%q", t.env)
+ }
+ if t.noenv != "" {
+ fmt.Fprintf(&buf, " no_proxy=%q", t.noenv)
+ }
+ req := "http://example.com"
+ if t.req != "" {
+ req = t.req
+ }
+ fmt.Fprintf(&buf, " req=%q", req)
+ return strings.TrimSpace(buf.String())
+}
+
+var proxyFromEnvTests = []proxyFromEnvTest{
+ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
+ {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
+ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
+ {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
+ {want: "<nil>"},
+ {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
func TestProxyFromEnvironment(t *testing.T) {
@@ -770,16 +1304,21 @@ func TestProxyFromEnvironment(t *testing.T) {
os.Setenv("http_proxy", "")
os.Setenv("NO_PROXY", "")
os.Setenv("no_proxy", "")
- for i, tt := range proxyFromEnvTests {
+ for _, tt := range proxyFromEnvTests {
os.Setenv("HTTP_PROXY", tt.env)
- req, _ := NewRequest("GET", "http://example.com", nil)
+ os.Setenv("NO_PROXY", tt.noenv)
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
url, err := ProxyFromEnvironment(req)
if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
- t.Errorf("%d. got error = %q, want %q", i, g, e)
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
continue
}
- if got := fmt.Sprintf("%s", url); got != tt.wanturl {
- t.Errorf("%d. got URL = %q, want %q", i, url, tt.wanturl)
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
}
}
}