diff options
author | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
---|---|---|
committer | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
commit | f154da9e12608589e8d5f0508f908a0c3e88a1bb (patch) | |
tree | f8255d51e10c6f1e0ed69702200b966c9556a431 /src/net/http/httputil | |
parent | 8d8329ed5dfb9622c82a9fbec6fd99a580f9c9f6 (diff) | |
download | golang-upstream/1.4.tar.gz |
Imported Upstream version 1.4upstream/1.4
Diffstat (limited to 'src/net/http/httputil')
-rw-r--r-- | src/net/http/httputil/dump.go | 284 | ||||
-rw-r--r-- | src/net/http/httputil/dump_test.go | 291 | ||||
-rw-r--r-- | src/net/http/httputil/httputil.go | 39 | ||||
-rw-r--r-- | src/net/http/httputil/persist.go | 429 | ||||
-rw-r--r-- | src/net/http/httputil/reverseproxy.go | 225 | ||||
-rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 213 |
6 files changed, 1481 insertions, 0 deletions
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go new file mode 100644 index 000000000..ac8f103f9 --- /dev/null +++ b/src/net/http/httputil/dump.go @@ -0,0 +1,284 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// One of the copies, say from b to r2, could be avoided by using a more +// elaborate trick where the other copy is made during Request/Response.Write. +// This would complicate things too much, given that these functions are for +// debugging only. +func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, nil, err + } + if err = b.Close(); err != nil { + return nil, nil, err + } + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// DumpRequestOut is like DumpRequest but includes +// headers that the standard http.Transport adds, +// such as User-Agent. +func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { + save := req.Body + dummyBody := false + if !body || req.Body == nil { + req.Body = nil + if req.ContentLength != 0 { + req.Body = ioutil.NopCloser(io.LimitReader(neverEnding('x'), req.ContentLength)) + dummyBody = true + } + } else { + var err error + save, req.Body, err = drainBody(req.Body) + if err != nil { + return nil, err + } + } + + // Since we're using the actual Transport code to write the request, + // switch to http so the Transport doesn't try to do an SSL + // negotiation with our dumpConn and its bytes.Buffer & pipe. + // The wire format for https and http are the same, anyway. + reqSend := req + if req.URL.Scheme == "https" { + reqSend = new(http.Request) + *reqSend = *req + reqSend.URL = new(url.URL) + *reqSend.URL = *req.URL + reqSend.URL.Scheme = "http" + } + + // Use the actual Transport code to record what we would send + // on the wire, but not using TCP. Use a Transport with a + // custom dialer that returns a fake net.Conn that waits + // for the full input (and recording it), and then responds + // with a dummy response. + var buf bytes.Buffer // records the output + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + // Wait for the request before replying with a dummy response: + go func() { + req, err := http.ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + } + dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\n\r\n") + }() + + t := &http.Transport{ + DisableKeepAlives: true, + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil + }, + } + + _, err := t.RoundTrip(reqSend) + + req.Body = save + if err != nil { + return nil, err + } + dump := buf.Bytes() + + // If we used a dummy body above, remove it now. + // TODO: if the req.ContentLength is large, we allocate memory + // unnecessarily just to slice it off here. But this is just + // a debug function, so this is acceptable for now. We could + // discard the body earlier if this matters. + if dummyBody { + if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 { + dump = dump[:i+4] + } + } + return dump, nil +} + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + r.r = <-r.c + } + return r.r.Read(p) +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +var reqWriteExcludeHeaderDump = map[string]bool{ + "Host": true, // not in Header map anyway + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// dumpAsReceived writes req to w in the form as it was received, or +// at least as accurately as possible from the information retained in +// the request. +func dumpAsReceived(req *http.Request, w io.Writer) error { + return nil +} + +// DumpRequest returns the as-received wire representation of req, +// optionally including the request body, for debugging. +// DumpRequest is semantically a no-op, but in order to +// dump the body, it reads the body data into memory and +// changes req.Body to refer to the in-memory copy. +// The documentation for http.Request.Write details which fields +// of req are used. +func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { + save := req.Body + if !body || req.Body == nil { + req.Body = nil + } else { + save, req.Body, err = drainBody(req.Body) + if err != nil { + return + } + } + + var b bytes.Buffer + + fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), + req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor) + + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&b, "Host: %s\r\n", host) + } + + chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + if req.Close { + fmt.Fprintf(&b, "Connection: close\r\n") + } + + err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) + if err != nil { + return + } + + io.WriteString(&b, "\r\n") + + if req.Body != nil { + var dest io.Writer = &b + if chunked { + dest = NewChunkedWriter(dest) + } + _, err = io.Copy(dest, req.Body) + if chunked { + dest.(io.Closer).Close() + io.WriteString(&b, "\r\n") + } + } + + req.Body = save + if err != nil { + return + } + dump = b.Bytes() + return +} + +// errNoBody is a sentinel error value used by failureToReadBody so we can detect +// that the lack of body was intentional. +var errNoBody = errors.New("sentinel error value") + +// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// Read. It's swapped in when we don't actually want to consume the +// body, but need a non-nil one, and want to distinguish the error +// from reading the dummy body. +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + +// DumpResponse is like DumpRequest but dumps a response. +func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { + var b bytes.Buffer + save := resp.Body + savecl := resp.ContentLength + + if !body { + resp.Body = failureToReadBody{} + } else if resp.Body == nil { + resp.Body = emptyBody + } else { + save, resp.Body, err = drainBody(resp.Body) + if err != nil { + return + } + } + err = resp.Write(&b) + if err == errNoBody { + err = nil + } + resp.Body = save + resp.ContentLength = savecl + if err != nil { + return nil, err + } + return b.Bytes(), nil +} diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go new file mode 100644 index 000000000..024ee5a86 --- /dev/null +++ b/src/net/http/httputil/dump_test.go @@ -0,0 +1,291 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "runtime" + "strings" + "testing" +) + +type dumpTest struct { + Req http.Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + WantDump string + WantDumpOut string + NoBody bool // if true, set DumpRequest{,Out} body to false +} + +var dumpTests = []dumpTest{ + + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantDump: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: http.Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: http.Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantDump: "GET /foo HTTP/1.0\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, + + { + Req: *mustNewRequest("GET", "http://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Test that an https URL doesn't try to do an SSL negotiation + // with a bytes.Buffer and hang with all goroutines not + // runnable. + { + Req: *mustNewRequest("GET", "https://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Request with Body, but Dump requested without it. + { + Req: http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + ContentLength: 6, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: []byte("abcdef"), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 6\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + + NoBody: true, + }, + + // Request with Body > 8196 (default buffer size) + { + Req: http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + ContentLength: 8193, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: bytes.Repeat([]byte("a"), 8193), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "Content-Length: 8193\r\n" + + "Accept-Encoding: gzip\r\n\r\n" + + strings.Repeat("a", 8193), + }, +} + +func TestDumpRequest(t *testing.T) { + numg0 := runtime.NumGoroutine() + for i, tt := range dumpTests { + setBody := func() { + if tt.Body == nil { + return + } + switch b := tt.Body.(type) { + case []byte: + tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) + case func() io.ReadCloser: + tt.Req.Body = b() + default: + t.Fatalf("Test %d: unsupported Body of %T", i, tt.Body) + } + } + setBody() + if tt.Req.Header == nil { + tt.Req.Header = make(http.Header) + } + + if tt.WantDump != "" { + setBody() + dump, err := DumpRequest(&tt.Req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequest #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDump { + t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump)) + continue + } + } + + if tt.WantDumpOut != "" { + setBody() + dump, err := DumpRequestOut(&tt.Req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequestOut #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDumpOut { + t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump)) + continue + } + } + } + if dg := runtime.NumGoroutine() - numg0; dg > 4 { + buf := make([]byte, 4096) + buf = buf[:runtime.Stack(buf, true)] + t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf) + } +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} + +func mustNewRequest(method, url string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, url, body) + if err != nil { + panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err)) + } + return req +} + +var dumpResTests = []struct { + res *http.Response + body bool + want string +}{ + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 50, + Header: http.Header{ + "Foo": []string{"Bar"}, + }, + Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + }, + body: false, // to verify we see 50, not empty or 3. + want: `HTTP/1.1 200 OK +Content-Length: 50 +Foo: Bar`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 3, + Body: ioutil.NopCloser(strings.NewReader("foo")), + }, + body: true, + want: `HTTP/1.1 200 OK +Content-Length: 3 + +foo`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + Body: ioutil.NopCloser(strings.NewReader("foo")), + TransferEncoding: []string{"chunked"}, + }, + body: true, + want: `HTTP/1.1 200 OK +Transfer-Encoding: chunked + +3 +foo +0`, + }, +} + +func TestDumpResponse(t *testing.T) { + for i, tt := range dumpResTests { + gotb, err := DumpResponse(tt.res, tt.body) + if err != nil { + t.Errorf("%d. DumpResponse = %v", i, err) + continue + } + got := string(gotb) + got = strings.TrimSpace(got) + got = strings.Replace(got, "\r", "", -1) + + if got != tt.want { + t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want) + } + } +} diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go new file mode 100644 index 000000000..2e523e9e2 --- /dev/null +++ b/src/net/http/httputil/httputil.go @@ -0,0 +1,39 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import ( + "io" + "net/http/internal" +) + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + return internal.NewChunkedReader(r) +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return internal.NewChunkedWriter(w) +} + +// ErrLineTooLong is returned when reading malformed chunked data +// with lines that are too long. +var ErrLineTooLong = internal.ErrLineTooLong diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go new file mode 100644 index 000000000..987bcc96b --- /dev/null +++ b/src/net/http/httputil/persist.go @@ -0,0 +1,429 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "errors" + "io" + "net" + "net/http" + "net/textproto" + "sync" +) + +var ( + ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"} + ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} + ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} +) + +// This is an API usage error - the local side is closed. +// ErrPersistEOF (above) reports that the remote side is closed. +var errClosed = errors.New("i/o operation on closed connection") + +// A ServerConn reads requests and sends responses over an underlying +// connection, until the HTTP keepalive logic commands an end. ServerConn +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, +// i.e. requests can be read out of sync (but in the same order) while the +// respective responses are sent. +// +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. +type ServerConn struct { + lk sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline +} + +// NewServerConn returns a new ServerConn reading and writing c. If r is not +// nil, it is the buffer to use when reading c. +// +// ServerConn is low-level and old. Applications should instead use Server +// in the net/http package. +func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)} +} + +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before Read has signaled the end of the keep-alive logic. The user +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { + sc.lk.Lock() + defer sc.lk.Unlock() + c = sc.c + r = sc.r + sc.c = nil + sc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Read returns the next request on the wire. An ErrPersistEOF is returned if +// it is gracefully determined that there are no more requests (e.g. after the +// first request on an HTTP/1.0 connection, or after a Connection:close on a +// HTTP/1.1 connection). +func (sc *ServerConn) Read() (req *http.Request, err error) { + + // Ensure ordered execution of Reads and Writes + id := sc.pipe.Next() + sc.pipe.StartRequest(id) + defer func() { + sc.pipe.EndRequest(id) + if req == nil { + sc.pipe.StartResponse(id) + sc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + sc.lk.Lock() + sc.pipereq[req] = id + sc.lk.Unlock() + } + }() + + sc.lk.Lock() + if sc.we != nil { // no point receiving if write-side broken or closed + defer sc.lk.Unlock() + return nil, sc.we + } + if sc.re != nil { + defer sc.lk.Unlock() + return nil, sc.re + } + if sc.r == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return nil, errClosed + } + r := sc.r + lastbody := sc.lastbody + sc.lastbody = nil + sc.lk.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + sc.lk.Lock() + defer sc.lk.Unlock() + sc.re = err + return nil, err + } + } + + req, err = http.ReadRequest(r) + sc.lk.Lock() + defer sc.lk.Unlock() + if err != nil { + if err == io.ErrUnexpectedEOF { + // A close from the opposing client is treated as a + // graceful close, even if there was some unparse-able + // data before the close. + sc.re = ErrPersistEOF + return nil, sc.re + } else { + sc.re = err + return req, err + } + } + sc.lastbody = req.Body + sc.nread++ + if req.Close { + sc.re = ErrPersistEOF + return req, sc.re + } + return req, err +} + +// Pending returns the number of unanswered requests +// that have been received on the connection. +func (sc *ServerConn) Pending() int { + sc.lk.Lock() + defer sc.lk.Unlock() + return sc.nread - sc.nwritten +} + +// Write writes resp in response to req. To close the connection gracefully, set the +// Response.Close field to true. Write should be considered operational until +// it returns an error, regardless of any errors returned on the Read side. +func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { + + // Retrieve the pipeline ID of this request/response pair + sc.lk.Lock() + id, ok := sc.pipereq[req] + delete(sc.pipereq, req) + if !ok { + sc.lk.Unlock() + return ErrPipeline + } + sc.lk.Unlock() + + // Ensure pipeline order + sc.pipe.StartResponse(id) + defer sc.pipe.EndResponse(id) + + sc.lk.Lock() + if sc.we != nil { + defer sc.lk.Unlock() + return sc.we + } + if sc.c == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return ErrClosed + } + c := sc.c + if sc.nread <= sc.nwritten { + defer sc.lk.Unlock() + return errors.New("persist server pipe count") + } + if resp.Close { + // After signaling a keep-alive close, any pipelined unread + // requests will be lost. It is up to the user to drain them + // before signaling. + sc.re = ErrPersistEOF + } + sc.lk.Unlock() + + err := resp.Write(c) + sc.lk.Lock() + defer sc.lk.Unlock() + if err != nil { + sc.we = err + return err + } + sc.nwritten++ + + return nil +} + +// A ClientConn sends request and receives headers over an underlying +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. +// +// ClientConn is low-level and old. Applications should instead use +// Client or Transport in the net/http package. +type ClientConn struct { + lk sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline + writeReq func(*http.Request, io.Writer) error +} + +// NewClientConn returns a new ClientConn reading and writing c. If r is not +// nil, it is the buffer to use when reading c. +// +// ClientConn is low-level and old. Applications should use Client or +// Transport in the net/http package. +func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ClientConn{ + c: c, + r: r, + pipereq: make(map[*http.Request]uint), + writeReq: (*http.Request).Write, + } +} + +// NewProxyClientConn works like NewClientConn but writes Requests +// using Request's WriteProxy method. +// +// New code should not use NewProxyClientConn. See Client or +// Transport in the net/http package instead. +func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + cc := NewClientConn(c, r) + cc.writeReq = (*http.Request).WriteProxy + return cc +} + +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before the user or Read have signaled the end of the keep-alive +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { + cc.lk.Lock() + defer cc.lk.Unlock() + c = cc.c + r = cc.r + cc.c = nil + cc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Write writes a request. An ErrPersistEOF error is returned if the connection +// has been closed in an HTTP keepalive sense. If req.Close equals true, the +// keepalive connection is logically closed after this request and the opposing +// server is informed. An ErrUnexpectedEOF indicates the remote closed the +// underlying TCP connection, which is usually considered as graceful close. +func (cc *ClientConn) Write(req *http.Request) (err error) { + + // Ensure ordered execution of Writes + id := cc.pipe.Next() + cc.pipe.StartRequest(id) + defer func() { + cc.pipe.EndRequest(id) + if err != nil { + cc.pipe.StartResponse(id) + cc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + cc.lk.Lock() + cc.pipereq[req] = id + cc.lk.Unlock() + } + }() + + cc.lk.Lock() + if cc.re != nil { // no point sending if read-side closed or broken + defer cc.lk.Unlock() + return cc.re + } + if cc.we != nil { + defer cc.lk.Unlock() + return cc.we + } + if cc.c == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return errClosed + } + c := cc.c + if req.Close { + // We write the EOF to the write-side error, because there + // still might be some pipelined reads + cc.we = ErrPersistEOF + } + cc.lk.Unlock() + + err = cc.writeReq(req, c) + cc.lk.Lock() + defer cc.lk.Unlock() + if err != nil { + cc.we = err + return err + } + cc.nwritten++ + + return nil +} + +// Pending returns the number of unanswered requests +// that have been sent on the connection. +func (cc *ClientConn) Pending() int { + cc.lk.Lock() + defer cc.lk.Unlock() + return cc.nwritten - cc.nread +} + +// Read reads the next response from the wire. A valid response might be +// returned together with an ErrPersistEOF, which means that the remote +// requested that this be the last request serviced. Read can be called +// concurrently with Write, but not with another Read. +func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { + // Retrieve the pipeline ID of this request/response pair + cc.lk.Lock() + id, ok := cc.pipereq[req] + delete(cc.pipereq, req) + if !ok { + cc.lk.Unlock() + return nil, ErrPipeline + } + cc.lk.Unlock() + + // Ensure pipeline order + cc.pipe.StartResponse(id) + defer cc.pipe.EndResponse(id) + + cc.lk.Lock() + if cc.re != nil { + defer cc.lk.Unlock() + return nil, cc.re + } + if cc.r == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return nil, errClosed + } + r := cc.r + lastbody := cc.lastbody + cc.lastbody = nil + cc.lk.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + cc.lk.Lock() + defer cc.lk.Unlock() + cc.re = err + return nil, err + } + } + + resp, err = http.ReadResponse(r, req) + cc.lk.Lock() + defer cc.lk.Unlock() + if err != nil { + cc.re = err + return resp, err + } + cc.lastbody = resp.Body + + cc.nread++ + + if resp.Close { + cc.re = ErrPersistEOF // don't send any more requests + return resp, cc.re + } + return resp, err +} + +// Do is convenience method that writes a request and reads a response. +func (cc *ClientConn) Do(req *http.Request) (resp *http.Response, err error) { + err = cc.Write(req) + if err != nil { + return + } + return cc.Read(req) +} diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go new file mode 100644 index 000000000..ab4637018 --- /dev/null +++ b/src/net/http/httputil/reverseproxy.go @@ -0,0 +1,225 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package httputil + +import ( + "io" + "log" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + outreq := new(http.Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(h) + } + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + defer res.Body.Close() + + for _, h := range hopHeaders { + res.Header.Del(h) + } + + copyHeader(rw.Header(), res.Header) + + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), + } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw + } + } + + io.Copy(dst, src) +} + +func (p *ReverseProxy) logf(format string, args ...interface{}) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration + + lk sync.Mutex // protects Write + Flush + done chan bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.lk.Lock() + defer m.lk.Unlock() + return m.dst.Write(p) +} + +func (m *maxLatencyWriter) flushLoop() { + t := time.NewTicker(m.latency) + defer t.Stop() + for { + select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return + case <-t.C: + m.lk.Lock() + m.dst.Flush() + m.lk.Unlock() + } + } +} + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go new file mode 100644 index 000000000..e9539b44b --- /dev/null +++ b/src/net/http/httputil/reverseproxy_test.go @@ -0,0 +1,213 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package httputil + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + hopHeaders = append(hopHeaders, fakeHopHeader) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("Upgrade", "foo") + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) +}{ + {"", "", ""}, + {"?sta=tic", "?us=er", "sta=tic&us=er"}, + {"", "?us=er", "us=er"}, + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Got-Query", r.URL.RawQuery) + w.Write([]byte("hi")) + })) + defer backend.Close() + + for i, tt := range proxyQueryTests { + backendURL, err := url.Parse(backend.URL + tt.baseSuffix) + if err != nil { + t.Fatal(err) + } + frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) + req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("%d. Get: %v", i, err) + } + if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + res.Body.Close() + frontend.Close() + } +} + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + done := make(chan bool) + onExitFlushLoop = func() { done <- true } + defer func() { onExitFlushLoop = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case <-done: + // OK + case <-time.After(5 * time.Second): + t.Error("maxLatencyWriter flushLoop() never exited") + } +} |