author     Michael Stapelberg <stapelberg@debian.org>  2013-03-04 21:27:36 +0100
committer  Michael Stapelberg <michael@stapelberg.de>  2013-03-04 21:27:36 +0100
commit     04b08da9af0c450d645ab7389d1467308cfc2db8 (patch)
tree       db247935fa4f2f94408edc3acd5d0d4f997aa0d8 /src/pkg/net/http/httputil
parent     917c5fb8ec48e22459d77e3849e6d388f93d3260 (diff)
download   golang-upstream/1.1_hg20130304.tar.gz

Imported Upstream version 1.1~hg20130304 (upstream/1.1_hg20130304)
Diffstat (limited to 'src/pkg/net/http/httputil')
-rw-r--r--  src/pkg/net/http/httputil/chunked.go             59
-rw-r--r--  src/pkg/net/http/httputil/chunked_test.go        54
-rw-r--r--  src/pkg/net/http/httputil/dump.go                 4
-rw-r--r--  src/pkg/net/http/httputil/reverseproxy.go        58
-rw-r--r--  src/pkg/net/http/httputil/reverseproxy_test.go   84
5 files changed, 213 insertions, 46 deletions
diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go
index 29eaf3475..b66d40951 100644
--- a/src/pkg/net/http/httputil/chunked.go
+++ b/src/pkg/net/http/httputil/chunked.go
@@ -13,10 +13,9 @@ package httputil
import (
"bufio"
- "bytes"
"errors"
+ "fmt"
"io"
- "strconv"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -24,7 +23,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
-// out of HTTP "chunked" format before returning it.
+// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
@@ -41,16 +40,17 @@ type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
+ buf [2]byte
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
- var line string
+ var line []byte
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
- cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
@@ -76,9 +76,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
- b := make([]byte, 2)
- if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
- if b[0] != '\r' || b[1] != '\n' {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
@@ -90,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
@@ -104,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
-
- // Chop off trailing white space.
- p = bytes.TrimRight(p, " \r\t\n")
-
- return p, nil
+ return trimTrailingWhitespace(p), nil
}
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
- p, e := readLineBytes(b)
- if e != nil {
- return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
}
- return string(p), nil
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -149,9 +146,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
return 0, nil
}
- head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
-
- if _, err = io.WriteString(cw.Wire, head); err != nil {
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
return 0, err
}
if n, err = cw.Wire.Write(data); err != nil {
@@ -170,3 +165,21 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for _, b := range v {
+ n <<= 4
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ n |= uint64(b)
+ }
+ return
+}
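
A minimal standalone sketch (not part of this commit) of the round trip the chunked.go rewrite serves: NewChunkedWriter emits the hex chunk-size lines on the wire, and NewChunkedReader decodes them again, internally via the new parseHexUint and the reused two-byte CRLF buffer. Only the exported httputil API is used here; the payload strings are placeholders.

package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http/httputil"
)

func main() {
	var wire bytes.Buffer

	// Encode two writes as HTTP/1.1 chunks: "6\r\nhello \r\n5\r\nworld\r\n0\r\n".
	w := httputil.NewChunkedWriter(&wire)
	w.Write([]byte("hello "))
	w.Write([]byte("world"))
	w.Close() // emits the terminating zero-length chunk

	fmt.Printf("wire: %q\n", wire.String())

	// Decode it again; the reader returns io.EOF at the final 0-length chunk.
	r := httputil.NewChunkedReader(&wire)
	body, err := ioutil.ReadAll(r)
	fmt.Printf("decoded: %q (err=%v)\n", body, err)
}

Each Write call becomes one chunk on the wire, so the chunk-size line (and hence parseHexUint on the read side) runs once per Write.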
diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go
index 155a32bdf..a06bffad5 100644
--- a/src/pkg/net/http/httputil/chunked_test.go
+++ b/src/pkg/net/http/httputil/chunked_test.go
@@ -11,7 +11,10 @@ package httputil
import (
"bytes"
+ "fmt"
+ "io"
"io/ioutil"
+ "runtime"
"testing"
)
@@ -39,3 +42,54 @@ func TestChunk(t *testing.T) {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ var buf bytes.Buffer
+ w := NewChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ r := NewChunkedReader(&buf)
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+ var ms runtime.MemStats
+ runtime.ReadMemStats(&ms)
+ m0 := ms.Mallocs
+
+ n, err := io.ReadFull(r, readBuf)
+
+ runtime.ReadMemStats(&ms)
+ mallocs := ms.Mallocs - m0
+ if mallocs > 1 {
+ t.Errorf("%d mallocs; want <= 1", mallocs)
+ }
+
+ if n != len(readBuf)-1 {
+ t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ for i := uint64(0); i <= 1234; i++ {
+ line := []byte(fmt.Sprintf("%x", i))
+ got, err := parseHexUint(line)
+ if err != nil {
+ t.Fatalf("on %d: %v", i, err)
+ }
+ if got != i {
+ t.Errorf("for input %q = %d; want %d", line, got, i)
+ }
+ }
+ _, err := parseHexUint([]byte("bogus"))
+ if err == nil {
+ t.Error("expected error on bogus input")
+ }
+}
diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go
index 892ef4ede..0b0035661 100644
--- a/src/pkg/net/http/httputil/dump.go
+++ b/src/pkg/net/http/httputil/dump.go
@@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
// Use the actual Transport code to record what we would send
// on the wire, but not using TCP. Use a Transport with a
- // customer dialer that returns a fake net.Conn that waits
+ // custom dialer that returns a fake net.Conn that waits
// for the full input (and recording it), and then responds
// with a dummy response.
var buf bytes.Buffer // records the output
@@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
t := &http.Transport{
Dial: func(net, addr string) (net.Conn, error) {
- return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
},
}
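
For context, a minimal sketch (not part of this commit) of the API the dump.go hunks adjust: DumpRequestOut pushes the request through a Transport with the fake in-memory dialer described in the comment above and returns the bytes that would have gone over the wire. The URL and body below are placeholders.

package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/httputil"
	"strings"
)

func main() {
	// Placeholder request; DumpRequestOut generates no real network traffic.
	req, err := http.NewRequest("POST", "http://example.com/upload",
		strings.NewReader("payload"))
	if err != nil {
		log.Fatal(err)
	}

	dump, err := httputil.DumpRequestOut(req, true) // true: include the body
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s", dump) // request line, headers the Transport would add, body
}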
diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go
index 9c4bd6e09..134c45299 100644
--- a/src/pkg/net/http/httputil/reverseproxy.go
+++ b/src/pkg/net/http/httputil/reverseproxy.go
@@ -17,6 +17,10 @@ import (
"time"
)
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
+
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
@@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Del("Connection")
}
- if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
- outreq.Header.Set("X-Forwarded-For", clientIp)
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ outreq.Header.Set("X-Forwarded-For", clientIP)
}
res, err := transport.RoundTrip(outreq)
@@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
return
}
+ defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
+ p.copyResponse(rw, res.Body)
+}
- if res.Body != nil {
- var dst io.Writer = rw
- if p.FlushInterval != 0 {
- if wf, ok := rw.(writeFlusher); ok {
- dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+ if p.FlushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: p.FlushInterval,
+ done: make(chan bool),
}
+ go mlw.flushLoop()
+ defer mlw.stop()
+ dst = mlw
}
- io.Copy(dst, res.Body)
}
+
+ io.Copy(dst, src)
}
type writeFlusher interface {
@@ -137,22 +156,14 @@ type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
- lk sync.Mutex // protects init of done, as well Write + Flush
+ lk sync.Mutex // protects Write + Flush
done chan bool
}
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
- if m.done == nil {
- m.done = make(chan bool)
- go m.flushLoop()
- }
- n, err = m.dst.Write(p)
- if err != nil {
- m.done <- true
- }
- return
+ return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
@@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
+ case <-m.done:
+ if onExitFlushLoop != nil {
+ onExitFlushLoop()
+ }
+ return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
- case <-m.done:
- return
}
}
panic("unreached")
}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }
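
A minimal sketch (not part of this commit) wiring up the two behaviours this file changes: the proxy now folds the client address into any existing X-Forwarded-For header inside ServeHTTP, and a non-zero FlushInterval routes the response body through maxLatencyWriter, whose flush goroutine is stopped via the done channel when copyResponse returns. The backend address, listen address, and interval are placeholders.

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"time"
)

func main() {
	// Placeholder backend; in practice this is the server being proxied.
	backend, err := url.Parse("http://127.0.0.1:8081")
	if err != nil {
		log.Fatal(err)
	}

	proxy := httputil.NewSingleHostReverseProxy(backend)

	// Flush buffered response data to the client every 100ms while copying,
	// exercising the maxLatencyWriter path above.
	proxy.FlushInterval = 100 * time.Millisecond

	// X-Forwarded-For folding needs no configuration; ServeHTTP applies it
	// per request before calling RoundTrip.
	log.Fatal(http.ListenAndServe(":8080", proxy))
}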
diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go
index 28e9c90ad..863927162 100644
--- a/src/pkg/net/http/httputil/reverseproxy_test.go
+++ b/src/pkg/net/http/httputil/reverseproxy_test.go
@@ -11,7 +11,9 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
+ "time"
)
func TestReverseProxy(t *testing.T) {
@@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) {
}
}
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
var proxyQueryTests = []struct {
baseSuffix string // suffix to add to backend URL
reqSuffix string // suffix to add to frontend's request URL
@@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) {
frontend.Close()
}
}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ done := make(chan bool)
+ onExitFlushLoop = func() { done <- true }
+ defer func() { onExitFlushLoop = nil }()
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+
+ select {
+ case <-done:
+ // OK
+ case <-time.After(5 * time.Second):
+ t.Error("maxLatencyWriter flushLoop() never exited")
+ }
+}