diff options
author | Michael Stapelberg <stapelberg@debian.org> | 2013-03-04 21:27:36 +0100 |
---|---|---|
committer | Michael Stapelberg <michael@stapelberg.de> | 2013-03-04 21:27:36 +0100 |
commit | 04b08da9af0c450d645ab7389d1467308cfc2db8 (patch) | |
tree | db247935fa4f2f94408edc3acd5d0d4f997aa0d8 /src/pkg/net/http/httputil | |
parent | 917c5fb8ec48e22459d77e3849e6d388f93d3260 (diff) | |
download | golang-upstream/1.1_hg20130304.tar.gz |
Imported Upstream version 1.1~hg20130304upstream/1.1_hg20130304
Diffstat (limited to 'src/pkg/net/http/httputil')
-rw-r--r-- | src/pkg/net/http/httputil/chunked.go | 59 | ||||
-rw-r--r-- | src/pkg/net/http/httputil/chunked_test.go | 54 | ||||
-rw-r--r-- | src/pkg/net/http/httputil/dump.go | 4 | ||||
-rw-r--r-- | src/pkg/net/http/httputil/reverseproxy.go | 58 | ||||
-rw-r--r-- | src/pkg/net/http/httputil/reverseproxy_test.go | 84 |
5 files changed, 213 insertions, 46 deletions
diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go index 29eaf3475..b66d40951 100644 --- a/src/pkg/net/http/httputil/chunked.go +++ b/src/pkg/net/http/httputil/chunked.go @@ -13,10 +13,9 @@ package httputil import ( "bufio" - "bytes" "errors" + "fmt" "io" - "strconv" ) const maxLineLength = 4096 // assumed <= bufio.defaultBufSize @@ -24,7 +23,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") // NewChunkedReader returns a new chunkedReader that translates the data read from r -// out of HTTP "chunked" format before returning it. +// out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // // NewChunkedReader is not needed by normal applications. The http package @@ -41,16 +40,17 @@ type chunkedReader struct { r *bufio.Reader n uint64 // unread bytes in chunk err error + buf [2]byte } func (cr *chunkedReader) beginChunk() { // chunk-size CRLF - var line string + var line []byte line, cr.err = readLine(cr.r) if cr.err != nil { return } - cr.n, cr.err = strconv.ParseUint(line, 16, 64) + cr.n, cr.err = parseHexUint(line) if cr.err != nil { return } @@ -76,9 +76,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.n -= uint64(n) if cr.n == 0 && cr.err == nil { // end of chunk (CRLF) - b := make([]byte, 2) - if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { - if b[0] != '\r' || b[1] != '\n' { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { cr.err = errors.New("malformed chunked encoding") } } @@ -90,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // Give up if the line exceeds maxLineLength. // The returned bytes are a pointer into storage in // the bufio, so they are only valid until the next bufio read. -func readLineBytes(b *bufio.Reader) (p []byte, err error) { +func readLine(b *bufio.Reader) (p []byte, err error) { if p, err = b.ReadSlice('\n'); err != nil { // We always know when EOF is coming. // If the caller asked for a line, there should be a line. @@ -104,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - - // Chop off trailing white space. - p = bytes.TrimRight(p, " \r\t\n") - - return p, nil + return trimTrailingWhitespace(p), nil } -// readLineBytes, but convert the bytes into a string. -func readLine(b *bufio.Reader) (s string, err error) { - p, e := readLineBytes(b) - if e != nil { - return "", e +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] } - return string(p), nil + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' } // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP @@ -149,9 +146,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) { return 0, nil } - head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" - - if _, err = io.WriteString(cw.Wire, head); err != nil { + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { return 0, err } if n, err = cw.Wire.Write(data); err != nil { @@ -170,3 +165,21 @@ func (cw *chunkedWriter) Close() error { _, err := io.WriteString(cw.Wire, "0\r\n") return err } + +func parseHexUint(v []byte) (n uint64, err error) { + for _, b := range v { + n <<= 4 + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + n |= uint64(b) + } + return +} diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go index 155a32bdf..a06bffad5 100644 --- a/src/pkg/net/http/httputil/chunked_test.go +++ b/src/pkg/net/http/httputil/chunked_test.go @@ -11,7 +11,10 @@ package httputil import ( "bytes" + "fmt" + "io" "io/ioutil" + "runtime" "testing" ) @@ -39,3 +42,54 @@ func TestChunk(t *testing.T) { t.Errorf("chunk reader read %q; want %q", g, e) } } + +func TestChunkReaderAllocs(t *testing.T) { + // temporarily set GOMAXPROCS to 1 as we are testing memory allocations + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + r := NewChunkedReader(&buf) + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + m0 := ms.Mallocs + + n, err := io.ReadFull(r, readBuf) + + runtime.ReadMemStats(&ms) + mallocs := ms.Mallocs - m0 + if mallocs > 1 { + t.Errorf("%d mallocs; want <= 1", mallocs) + } + + if n != len(readBuf)-1 { + t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + } +} + +func TestParseHexUint(t *testing.T) { + for i := uint64(0); i <= 1234; i++ { + line := []byte(fmt.Sprintf("%x", i)) + got, err := parseHexUint(line) + if err != nil { + t.Fatalf("on %d: %v", i, err) + } + if got != i { + t.Errorf("for input %q = %d; want %d", line, got, i) + } + } + _, err := parseHexUint([]byte("bogus")) + if err == nil { + t.Error("expected error on bogus input") + } +} diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go index 892ef4ede..0b0035661 100644 --- a/src/pkg/net/http/httputil/dump.go +++ b/src/pkg/net/http/httputil/dump.go @@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { // Use the actual Transport code to record what we would send // on the wire, but not using TCP. Use a Transport with a - // customer dialer that returns a fake net.Conn that waits + // custom dialer that returns a fake net.Conn that waits // for the full input (and recording it), and then responds // with a dummy response. var buf bytes.Buffer // records the output @@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { t := &http.Transport{ Dial: func(net, addr string) (net.Conn, error) { - return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil }, } diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 9c4bd6e09..134c45299 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -17,6 +17,10 @@ import ( "time" ) +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. @@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del("Connection") } - if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - outreq.Header.Set("X-Forwarded-For", clientIp) + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) } res, err := transport.RoundTrip(outreq) @@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) return } + defer res.Body.Close() copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} - if res.Body != nil { - var dst io.Writer = rw - if p.FlushInterval != 0 { - if wf, ok := rw.(writeFlusher); ok { - dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw } - io.Copy(dst, res.Body) } + + io.Copy(dst, src) } type writeFlusher interface { @@ -137,22 +156,14 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects init of done, as well Write + Flush + lk sync.Mutex // protects Write + Flush done chan bool } -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { +func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.lk.Lock() defer m.lk.Unlock() - if m.done == nil { - m.done = make(chan bool) - go m.flushLoop() - } - n, err = m.dst.Write(p) - if err != nil { - m.done <- true - } - return + return m.dst.Write(p) } func (m *maxLatencyWriter) flushLoop() { @@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() { defer t.Stop() for { select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return case <-t.C: m.lk.Lock() m.dst.Flush() m.lk.Unlock() - case <-m.done: - return } } panic("unreached") } + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 28e9c90ad..863927162 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -11,7 +11,9 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestReverseProxy(t *testing.T) { @@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) { } } +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + var proxyQueryTests = []struct { baseSuffix string // suffix to add to backend URL reqSuffix string // suffix to add to frontend's request URL @@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) { frontend.Close() } } + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + done := make(chan bool) + onExitFlushLoop = func() { done <- true } + defer func() { onExitFlushLoop = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case <-done: + // OK + case <-time.After(5 * time.Second): + t.Error("maxLatencyWriter flushLoop() never exited") + } +} |