summaryrefslogtreecommitdiff
path: root/src/pkg/http
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/http')
-rw-r--r--src/pkg/http/Makefile1
-rw-r--r--src/pkg/http/cgi/host.go117
-rw-r--r--src/pkg/http/cgi/host_test.go37
-rwxr-xr-xsrc/pkg/http/cgi/testdata/test.cgi5
-rw-r--r--src/pkg/http/client.go59
-rw-r--r--src/pkg/http/client_test.go49
-rw-r--r--src/pkg/http/cookie.go21
-rw-r--r--src/pkg/http/cookie_test.go6
-rw-r--r--src/pkg/http/dump.go2
-rw-r--r--src/pkg/http/export_test.go7
-rw-r--r--src/pkg/http/fcgi/Makefile12
-rw-r--r--src/pkg/http/fcgi/child.go328
-rw-r--r--src/pkg/http/fcgi/fcgi.go271
-rw-r--r--src/pkg/http/fcgi/fcgi_test.go114
-rw-r--r--src/pkg/http/fs.go2
-rw-r--r--src/pkg/http/fs_test.go2
-rw-r--r--src/pkg/http/httptest/recorder.go2
-rw-r--r--src/pkg/http/persist.go44
-rw-r--r--src/pkg/http/proxy_test.go23
-rw-r--r--src/pkg/http/request.go158
-rw-r--r--src/pkg/http/request_test.go122
-rw-r--r--src/pkg/http/requestwrite_test.go61
-rw-r--r--src/pkg/http/response_test.go135
-rw-r--r--src/pkg/http/reverseproxy.go100
-rw-r--r--src/pkg/http/reverseproxy_test.go50
-rw-r--r--src/pkg/http/serve_test.go103
-rw-r--r--src/pkg/http/server.go105
-rw-r--r--src/pkg/http/transfer.go12
-rw-r--r--src/pkg/http/transport.go54
-rw-r--r--src/pkg/http/transport_test.go117
30 files changed, 1989 insertions, 130 deletions
diff --git a/src/pkg/http/Makefile b/src/pkg/http/Makefile
index 389b04222..2a2a2a3be 100644
--- a/src/pkg/http/Makefile
+++ b/src/pkg/http/Makefile
@@ -16,6 +16,7 @@ GOFILES=\
persist.go\
request.go\
response.go\
+ reverseproxy.go\
server.go\
status.go\
transfer.go\
diff --git a/src/pkg/http/cgi/host.go b/src/pkg/http/cgi/host.go
index a713d7c3c..136d4e4ee 100644
--- a/src/pkg/http/cgi/host.go
+++ b/src/pkg/http/cgi/host.go
@@ -25,20 +25,40 @@ import (
"os"
"path/filepath"
"regexp"
+ "runtime"
"strconv"
"strings"
)
var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
+var osDefaultInheritEnv = map[string][]string{
+ "darwin": []string{"DYLD_LIBRARY_PATH"},
+ "freebsd": []string{"LD_LIBRARY_PATH"},
+ "hpux": []string{"LD_LIBRARY_PATH", "SHLIB_PATH"},
+ "linux": []string{"LD_LIBRARY_PATH"},
+ "windows": []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"},
+}
+
// Handler runs an executable in a subprocess with a CGI environment.
type Handler struct {
Path string // path to the CGI executable
Root string // root URI prefix of handler or empty for "/"
- Env []string // extra environment variables to set, if any
- Logger *log.Logger // optional log for errors or nil to use log.Print
- Args []string // optional arguments to pass to child process
+ Env []string // extra environment variables to set, if any, as "key=value"
+ InheritEnv []string // environment variables to inherit from host, as "key"
+ Logger *log.Logger // optional log for errors or nil to use log.Print
+ Args []string // optional arguments to pass to child process
+
+ // PathLocationHandler specifies the root http Handler that
+ // should handle internal redirects when the CGI process
+ // returns a Location header value starting with a "/", as
+ // specified in RFC 3875 ยง 6.3.2. This will likely be
+ // http.DefaultServeMux.
+ //
+ // If nil, a CGI response with a local URI path is instead sent
+ // back to the client and not redirected internally.
+ PathLocationHandler http.Handler
}
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
@@ -110,6 +130,24 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
env = append(env, h.Env...)
}
+ path := os.Getenv("PATH")
+ if path == "" {
+ path = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
+ }
+ env = append(env, "PATH="+path)
+
+ for _, e := range h.InheritEnv {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ for _, e := range osDefaultInheritEnv[runtime.GOOS] {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
cwd, pathBase := filepath.Split(h.Path)
if cwd == "" {
cwd = "."
@@ -143,13 +181,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024)
- headers := rw.Header()
- statusCode := http.StatusOK
+ headers := make(http.Header)
+ statusCode := 0
for {
line, isPrefix, err := linebody.ReadLine()
if isPrefix {
rw.WriteHeader(http.StatusInternalServerError)
- h.printf("CGI: long header line from subprocess.")
+ h.printf("cgi: long header line from subprocess.")
return
}
if err == os.EOF {
@@ -157,7 +195,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
- h.printf("CGI: error reading headers: %v", err)
+ h.printf("cgi: error reading headers: %v", err)
return
}
if len(line) == 0 {
@@ -165,7 +203,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
parts := strings.Split(string(line), ":", 2)
if len(parts) < 2 {
- h.printf("CGI: bogus header line: %s", string(line))
+ h.printf("cgi: bogus header line: %s", string(line))
continue
}
header, val := parts[0], parts[1]
@@ -174,13 +212,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
switch {
case header == "Status":
if len(val) < 3 {
- h.printf("CGI: bogus status (short): %q", val)
+ h.printf("cgi: bogus status (short): %q", val)
return
}
code, err := strconv.Atoi(val[0:3])
if err != nil {
- h.printf("CGI: bogus status: %q", val)
- h.printf("CGI: line was %q", line)
+ h.printf("cgi: bogus status: %q", val)
+ h.printf("cgi: line was %q", line)
return
}
statusCode = code
@@ -188,11 +226,35 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
headers.Add(header, val)
}
}
+
+ if loc := headers.Get("Location"); loc != "" {
+ if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
+ h.handleInternalRedirect(rw, req, loc)
+ return
+ }
+ if statusCode == 0 {
+ statusCode = http.StatusFound
+ }
+ }
+
+ if statusCode == 0 {
+ statusCode = http.StatusOK
+ }
+
+ // Copy headers to rw's headers, after we've decided not to
+ // go into handleInternalRedirect, which won't want its rw
+ // headers to have been touched.
+ for k, vv := range headers {
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+
rw.WriteHeader(statusCode)
_, err = io.Copy(rw, linebody)
if err != nil {
- h.printf("CGI: copy error: %v", err)
+ h.printf("cgi: copy error: %v", err)
}
}
@@ -204,6 +266,37 @@ func (h *Handler) printf(format string, v ...interface{}) {
}
}
+func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) {
+ url, err := req.URL.ParseURL(path)
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error resolving local URI path %q: %v", path, err)
+ return
+ }
+ // TODO: RFC 3875 isn't clear if only GET is supported, but it
+ // suggests so: "Note that any message-body attached to the
+ // request (such as for a POST request) may not be available
+ // to the resource that is the target of the redirect." We
+ // should do some tests against Apache to see how it handles
+ // POST, HEAD, etc. Does the internal redirect get the same
+ // method or just GET? What about incoming headers?
+ // (e.g. Cookies) Which headers, if any, are copied into the
+ // second request?
+ newReq := &http.Request{
+ Method: "GET",
+ URL: url,
+ RawURL: path,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: url.Host,
+ RemoteAddr: req.RemoteAddr,
+ TLS: req.TLS,
+ }
+ h.PathLocationHandler.ServeHTTP(rw, newReq)
+}
+
func upperCaseAndUnderscore(rune int) int {
switch {
case rune >= 'a' && rune <= 'z':
diff --git a/src/pkg/http/cgi/host_test.go b/src/pkg/http/cgi/host_test.go
index e8084b113..9ac085f2f 100644
--- a/src/pkg/http/cgi/host_test.go
+++ b/src/pkg/http/cgi/host_test.go
@@ -271,3 +271,40 @@ Transfer-Encoding: chunked
expected, got)
}
}
+
+func TestRedirect(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil)
+ if e, g := 302, rec.Code; e != g {
+ t.Errorf("expected status code %d; got %d", e, g)
+ }
+ if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g {
+ t.Errorf("expected Location header of %q; got %q", e, g)
+ }
+}
+
+func TestInternalRedirect(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
+ fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
+ })
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ PathLocationHandler: baseHandler,
+ }
+ expectedMap := map[string]string{
+ "basepath": "/foo",
+ "remoteaddr": "1.2.3.4",
+ }
+ runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
diff --git a/src/pkg/http/cgi/testdata/test.cgi b/src/pkg/http/cgi/testdata/test.cgi
index 253589eed..a1b2ff893 100755
--- a/src/pkg/http/cgi/testdata/test.cgi
+++ b/src/pkg/http/cgi/testdata/test.cgi
@@ -11,6 +11,11 @@ use CGI;
my $q = CGI->new;
my $params = $q->Vars;
+if ($params->{"loc"}) {
+ print "Location: $params->{loc}\r\n\r\n";
+ exit(0);
+}
+
my $NL = "\r\n";
$NL = "\n" if $params->{mode} eq "NL";
diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go
index daba3a89b..d73cbc855 100644
--- a/src/pkg/http/client.go
+++ b/src/pkg/http/client.go
@@ -22,6 +22,16 @@ import (
// Client is not yet very configurable.
type Client struct {
Transport RoundTripper // if nil, DefaultTransport is used
+
+ // If CheckRedirect is not nil, the client calls it before
+ // following an HTTP redirect. The arguments req and via
+ // are the upcoming request and the requests made already,
+ // oldest first. If CheckRedirect returns an error, the client
+ // returns that error instead of issue the Request req.
+ //
+ // If CheckRedirect is nil, the Client uses its default policy,
+ // which is to stop after 10 consecutive requests.
+ CheckRedirect func(req *Request, via []*Request) os.Error
}
// DefaultClient is the default Client and is used by Get, Head, and Post.
@@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool {
}
// Get issues a GET to the specified URL. If the response is one of the following
-// redirect codes, it follows the redirect, up to a maximum of 10 redirects:
+// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
//
// 301 (Moved Permanently)
// 302 (Found)
@@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) {
return DefaultClient.Get(url)
}
-// Get issues a GET to the specified URL. If the response is one of the following
-// redirect codes, it follows the redirect, up to a maximum of 10 redirects:
+// Get issues a GET to the specified URL. If the response is one of the
+// following redirect codes, Get follows the redirect after calling the
+// Client's CheckRedirect function.
//
// 301 (Moved Permanently)
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
//
-// finalURL is the URL from which the response was fetched -- identical to the
-// input URL unless redirects were followed.
+// finalURL is the URL from which the response was fetched -- identical
+// to the input URL unless redirects were followed.
//
// Caller should close r.Body when done reading from it.
func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
// TODO: if/when we add cookie support, the redirected request shouldn't
// necessarily supply the same cookies as the original.
- // TODO: set referrer header on redirects.
var base *URL
- // TODO: remove this hard-coded 10 and use the Client's policy
- // (ClientConfig) instead.
- for redirect := 0; ; redirect++ {
- if redirect >= 10 {
- err = os.ErrorString("stopped after 10 redirects")
- break
- }
+ redirectChecker := c.CheckRedirect
+ if redirectChecker == nil {
+ redirectChecker = defaultCheckRedirect
+ }
+ var via []*Request
+ for redirect := 0; ; redirect++ {
var req Request
req.Method = "GET"
- req.ProtoMajor = 1
- req.ProtoMinor = 1
+ req.Header = make(Header)
if base == nil {
req.URL, err = ParseURL(url)
} else {
@@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
if err != nil {
break
}
+ if len(via) > 0 {
+ // Add the Referer header.
+ lastReq := via[len(via)-1]
+ if lastReq.URL.Scheme != "https" {
+ req.Referer = lastReq.URL.String()
+ }
+
+ err = redirectChecker(&req, via)
+ if err != nil {
+ break
+ }
+ }
+
url = req.URL.String()
if r, err = send(&req, c.Transport); err != nil {
break
@@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
break
}
base = req.URL
+ via = append(via, &req)
continue
}
finalURL = url
@@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) {
return
}
+func defaultCheckRedirect(req *Request, via []*Request) os.Error {
+ if len(via) >= 10 {
+ return os.ErrorString("stopped after 10 redirects")
+ }
+ return nil
+}
+
// Post issues a POST to the specified URL.
//
// Caller should close r.Body when done reading from it.
diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go
index 3a6f83425..59d62c1c9 100644
--- a/src/pkg/http/client_test.go
+++ b/src/pkg/http/client_test.go
@@ -12,6 +12,7 @@ import (
"http/httptest"
"io/ioutil"
"os"
+ "strconv"
"strings"
"testing"
)
@@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) {
t.Errorf("expected non-nil request Header")
}
}
+
+func TestRedirects(t *testing.T) {
+ var ts *httptest.Server
+ ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ n, _ := strconv.Atoi(r.FormValue("n"))
+ // Test Referer header. (7 is arbitrary position to test at)
+ if n == 7 {
+ if g, e := r.Referer, ts.URL+"/?n=6"; e != g {
+ t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
+ }
+ }
+ if n < 15 {
+ Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
+ return
+ }
+ fmt.Fprintf(w, "n=%d", n)
+ }))
+ defer ts.Close()
+
+ c := &Client{}
+ _, _, err := c.Get(ts.URL)
+ if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client, expected error %q, got %q", e, g)
+ }
+
+ var checkErr os.Error
+ var lastVia []*Request
+ c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error {
+ lastVia = via
+ return checkErr
+ }}
+ _, finalUrl, err := c.Get(ts.URL)
+ if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with custom client, expected error %q, got %q", e, g)
+ }
+ if !strings.HasSuffix(finalUrl, "/?n=15") {
+ t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl)
+ }
+ if e, g := 15, len(lastVia); e != g {
+ t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
+ }
+
+ checkErr = os.NewError("no redirects allowed")
+ _, finalUrl, err = c.Get(ts.URL)
+ if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with redirects forbidden, expected error %q, got %q", e, g)
+ }
+}
diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go
index 2bb66e58e..2c01826a1 100644
--- a/src/pkg/http/cookie.go
+++ b/src/pkg/http/cookie.go
@@ -142,12 +142,12 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error {
var b bytes.Buffer
for _, c := range kk {
b.Reset()
- fmt.Fprintf(&b, "%s=%s", c.Name, c.Value)
+ fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value))
if len(c.Path) > 0 {
- fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path))
+ fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path))
}
if len(c.Domain) > 0 {
- fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain))
+ fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain))
}
if len(c.Expires.Zone) > 0 {
fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123))
@@ -225,7 +225,7 @@ func readCookies(h Header) []*Cookie {
func writeCookies(w io.Writer, kk []*Cookie) os.Error {
lines := make([]string, 0, len(kk))
for _, c := range kk {
- lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value))
+ lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", sanitizeName(c.Name), sanitizeValue(c.Value)))
}
sort.SortStrings(lines)
for _, l := range lines {
@@ -236,6 +236,19 @@ func writeCookies(w io.Writer, kk []*Cookie) os.Error {
return nil
}
+func sanitizeName(n string) string {
+ n = strings.Replace(n, "\n", "-", -1)
+ n = strings.Replace(n, "\r", "-", -1)
+ return n
+}
+
+func sanitizeValue(v string) string {
+ v = strings.Replace(v, "\n", " ", -1)
+ v = strings.Replace(v, "\r", " ", -1)
+ v = strings.Replace(v, ";", " ", -1)
+ return v
+}
+
func unquoteCookieValue(v string) string {
if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' {
return v[1 : len(v)-1]
diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go
index db0997040..a3ae85cd6 100644
--- a/src/pkg/http/cookie_test.go
+++ b/src/pkg/http/cookie_test.go
@@ -21,9 +21,13 @@ var writeSetCookiesTests = []struct {
[]*Cookie{
&Cookie{Name: "cookie-1", Value: "v$1"},
&Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600},
+ &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"},
+ &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"},
},
"Set-Cookie: cookie-1=v$1\r\n" +
- "Set-Cookie: cookie-2=two; Max-Age=3600\r\n",
+ "Set-Cookie: cookie-2=two; Max-Age=3600\r\n" +
+ "Set-Cookie: cookie-3=three; Domain=.example.com\r\n" +
+ "Set-Cookie: cookie-4=four; Path=/restricted/\r\n",
},
}
diff --git a/src/pkg/http/dump.go b/src/pkg/http/dump.go
index 306c45bc2..358980f7c 100644
--- a/src/pkg/http/dump.go
+++ b/src/pkg/http/dump.go
@@ -31,6 +31,8 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) {
// DumpRequest is semantically a no-op, but in order to
// dump the body, it reads the body data into memory and
// changes req.Body to refer to the in-memory copy.
+// The documentation for Request.Write details which fields
+// of req are used.
func DumpRequest(req *Request, body bool) (dump []byte, err os.Error) {
var b bytes.Buffer
save := req.Body
diff --git a/src/pkg/http/export_test.go b/src/pkg/http/export_test.go
index 47c687760..3fe658641 100644
--- a/src/pkg/http/export_test.go
+++ b/src/pkg/http/export_test.go
@@ -32,3 +32,10 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
}
return len(conns)
}
+
+func NewTestTimeoutHandler(handler Handler, ch <-chan int64) Handler {
+ f := func() <-chan int64 {
+ return ch
+ }
+ return &timeoutHandler{handler, f, ""}
+}
diff --git a/src/pkg/http/fcgi/Makefile b/src/pkg/http/fcgi/Makefile
new file mode 100644
index 000000000..bc01cdea9
--- /dev/null
+++ b/src/pkg/http/fcgi/Makefile
@@ -0,0 +1,12 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=http/fcgi
+GOFILES=\
+ child.go\
+ fcgi.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/http/fcgi/child.go b/src/pkg/http/fcgi/child.go
new file mode 100644
index 000000000..114052bee
--- /dev/null
+++ b/src/pkg/http/fcgi/child.go
@@ -0,0 +1,328 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+// This file implements FastCGI from the perspective of a child process.
+
+import (
+ "fmt"
+ "http"
+ "io"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// request holds the state for an in-progress request. As soon as it's complete,
+// it's converted to an http.Request.
+type request struct {
+ pw *io.PipeWriter
+ reqId uint16
+ params map[string]string
+ buf [1024]byte
+ rawParams []byte
+ keepConn bool
+}
+
+func newRequest(reqId uint16, flags uint8) *request {
+ r := &request{
+ reqId: reqId,
+ params: map[string]string{},
+ keepConn: flags&flagKeepConn != 0,
+ }
+ r.rawParams = r.buf[:0]
+ return r
+}
+
+// TODO(eds): copied from http/cgi
+var skipHeader = map[string]bool{
+ "HTTP_HOST": true,
+ "HTTP_REFERER": true,
+ "HTTP_USER_AGENT": true,
+}
+
+// httpRequest converts r to an http.Request.
+// TODO(eds): this is very similar to http/cgi's requestFromEnvironment
+func (r *request) httpRequest(body io.ReadCloser) (*http.Request, os.Error) {
+ req := &http.Request{
+ Method: r.params["REQUEST_METHOD"],
+ RawURL: r.params["REQUEST_URI"],
+ Body: body,
+ Header: http.Header{},
+ Trailer: http.Header{},
+ Proto: r.params["SERVER_PROTOCOL"],
+ }
+
+ var ok bool
+ req.ProtoMajor, req.ProtoMinor, ok = http.ParseHTTPVersion(req.Proto)
+ if !ok {
+ return nil, os.NewError("fcgi: invalid HTTP version")
+ }
+
+ req.Host = r.params["HTTP_HOST"]
+ req.Referer = r.params["HTTP_REFERER"]
+ req.UserAgent = r.params["HTTP_USER_AGENT"]
+
+ if lenstr := r.params["CONTENT_LENGTH"]; lenstr != "" {
+ clen, err := strconv.Atoi64(r.params["CONTENT_LENGTH"])
+ if err != nil {
+ return nil, os.NewError("fcgi: bad CONTENT_LENGTH parameter: " + lenstr)
+ }
+ req.ContentLength = clen
+ }
+
+ if req.Host != "" {
+ req.RawURL = "http://" + req.Host + r.params["REQUEST_URI"]
+ url, err := http.ParseURL(req.RawURL)
+ if err != nil {
+ return nil, os.NewError("fcgi: failed to parse host and REQUEST_URI into a URL: " + req.RawURL)
+ }
+ req.URL = url
+ }
+ if req.URL == nil {
+ req.RawURL = r.params["REQUEST_URI"]
+ url, err := http.ParseURL(req.RawURL)
+ if err != nil {
+ return nil, os.NewError("fcgi: failed to parse REQUEST_URI into a URL: " + req.RawURL)
+ }
+ req.URL = url
+ }
+
+ for key, val := range r.params {
+ if strings.HasPrefix(key, "HTTP_") && !skipHeader[key] {
+ req.Header.Add(strings.Replace(key[5:], "_", "-", -1), val)
+ }
+ }
+ return req, nil
+}
+
+// parseParams reads an encoded []byte into Params.
+func (r *request) parseParams() {
+ text := r.rawParams
+ r.rawParams = nil
+ for len(text) > 0 {
+ keyLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ valLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ key := readString(text, keyLen)
+ text = text[keyLen:]
+ val := readString(text, valLen)
+ text = text[valLen:]
+ r.params[key] = val
+ }
+}
+
+// response implements http.ResponseWriter.
+type response struct {
+ req *request
+ header http.Header
+ w *bufWriter
+ wroteHeader bool
+}
+
+func newResponse(c *child, req *request) *response {
+ return &response{
+ req: req,
+ header: http.Header{},
+ w: newWriter(c.conn, typeStdout, req.reqId),
+ }
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(data []byte) (int, os.Error) {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ return r.w.Write(data)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.wroteHeader {
+ return
+ }
+ r.wroteHeader = true
+ if code == http.StatusNotModified {
+ // Must not have body.
+ r.header.Del("Content-Type")
+ r.header.Del("Content-Length")
+ r.header.Del("Transfer-Encoding")
+ } else if r.header.Get("Content-Type") == "" {
+ r.header.Set("Content-Type", "text/html; charset=utf-8")
+ }
+
+ if r.header.Get("Date") == "" {
+ r.header.Set("Date", time.UTC().Format(http.TimeFormat))
+ }
+
+ fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code))
+ // TODO(eds): this is duplicated in http and http/cgi
+ for k, vv := range r.header {
+ for _, v := range vv {
+ v = strings.Replace(v, "\n", "", -1)
+ v = strings.Replace(v, "\r", "", -1)
+ v = strings.TrimSpace(v)
+ fmt.Fprintf(r.w, "%s: %s\r\n", k, v)
+ }
+ }
+ r.w.WriteString("\r\n")
+}
+
+func (r *response) Flush() {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ r.w.Flush()
+}
+
+func (r *response) Close() os.Error {
+ r.Flush()
+ return r.w.Close()
+}
+
+type child struct {
+ conn *conn
+ handler http.Handler
+}
+
+func newChild(rwc net.Conn, handler http.Handler) *child {
+ return &child{newConn(rwc), handler}
+}
+
+func (c *child) serve() {
+ requests := map[uint16]*request{}
+ defer c.conn.Close()
+ var rec record
+ var br beginRequest
+ for {
+ if err := rec.read(c.conn.rwc); err != nil {
+ return
+ }
+
+ req, ok := requests[rec.h.Id]
+ if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
+ // The spec says to ignore unknown request IDs.
+ continue
+ }
+ if ok && rec.h.Type == typeBeginRequest {
+ // The server is trying to begin a request with the same ID
+ // as an in-progress request. This is an error.
+ return
+ }
+
+ switch rec.h.Type {
+ case typeBeginRequest:
+ if err := br.read(rec.content()); err != nil {
+ return
+ }
+ if br.role != roleResponder {
+ c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
+ break
+ }
+ requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
+ case typeParams:
+ // NOTE(eds): Technically a key-value pair can straddle the boundary
+ // between two packets. We buffer until we've received all parameters.
+ if len(rec.content()) > 0 {
+ req.rawParams = append(req.rawParams, rec.content()...)
+ break
+ }
+ req.parseParams()
+ case typeStdin:
+ content := rec.content()
+ if req.pw == nil {
+ var body io.ReadCloser
+ if len(content) > 0 {
+ // body could be an io.LimitReader, but it shouldn't matter
+ // as long as both sides are behaving.
+ body, req.pw = io.Pipe()
+ }
+ go c.serveRequest(req, body)
+ }
+ if len(content) > 0 {
+ // TODO(eds): This blocks until the handler reads from the pipe.
+ // If the handler takes a long time, it might be a problem.
+ req.pw.Write(content)
+ } else if req.pw != nil {
+ req.pw.Close()
+ }
+ case typeGetValues:
+ values := map[string]string{"FCGI_MPXS_CONNS": "1"}
+ c.conn.writePairs(0, typeGetValuesResult, values)
+ case typeData:
+ // If the filter role is implemented, read the data stream here.
+ case typeAbortRequest:
+ requests[rec.h.Id] = nil, false
+ c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
+ if !req.keepConn {
+ // connection will close upon return
+ return
+ }
+ default:
+ b := make([]byte, 8)
+ b[0] = rec.h.Type
+ c.conn.writeRecord(typeUnknownType, 0, b)
+ }
+ }
+}
+
+func (c *child) serveRequest(req *request, body io.ReadCloser) {
+ r := newResponse(c, req)
+ httpReq, err := req.httpRequest(body)
+ if err != nil {
+ // there was an error reading the request
+ r.WriteHeader(http.StatusInternalServerError)
+ c.conn.writeRecord(typeStderr, req.reqId, []byte(err.String()))
+ } else {
+ c.handler.ServeHTTP(r, httpReq)
+ }
+ if body != nil {
+ body.Close()
+ }
+ r.Close()
+ c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete)
+ if !req.keepConn {
+ c.conn.Close()
+ }
+}
+
+// Serve accepts incoming FastCGI connections on the listener l, creating a new
+// service thread for each. The service threads read requests and then call handler
+// to reply to them.
+// If l is nil, Serve accepts connections on stdin.
+// If handler is nil, http.DefaultServeMux is used.
+func Serve(l net.Listener, handler http.Handler) os.Error {
+ if l == nil {
+ var err os.Error
+ l, err = net.FileListener(os.Stdin)
+ if err != nil {
+ return err
+ }
+ defer l.Close()
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ for {
+ rw, err := l.Accept()
+ if err != nil {
+ return err
+ }
+ c := newChild(rw, handler)
+ go c.serve()
+ }
+ panic("unreachable")
+}
diff --git a/src/pkg/http/fcgi/fcgi.go b/src/pkg/http/fcgi/fcgi.go
new file mode 100644
index 000000000..8e2e1cd3c
--- /dev/null
+++ b/src/pkg/http/fcgi/fcgi.go
@@ -0,0 +1,271 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fcgi implements the FastCGI protocol.
+// Currently only the responder role is supported.
+// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22
+package fcgi
+
+// This file defines the raw protocol and some utilities used by the child and
+// the host.
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "io"
+ "os"
+ "sync"
+)
+
+const (
+ // Packet Types
+ typeBeginRequest = iota + 1
+ typeAbortRequest
+ typeEndRequest
+ typeParams
+ typeStdin
+ typeStdout
+ typeStderr
+ typeData
+ typeGetValues
+ typeGetValuesResult
+ typeUnknownType
+)
+
+// keep the connection between web-server and responder open after request
+const flagKeepConn = 1
+
+const (
+ maxWrite = 65535 // maximum record body
+ maxPad = 255
+)
+
+const (
+ roleResponder = iota + 1 // only Responders are implemented.
+ roleAuthorizer
+ roleFilter
+)
+
+const (
+ statusRequestComplete = iota
+ statusCantMultiplex
+ statusOverloaded
+ statusUnknownRole
+)
+
+const headerLen = 8
+
+type header struct {
+ Version uint8
+ Type uint8
+ Id uint16
+ ContentLength uint16
+ PaddingLength uint8
+ Reserved uint8
+}
+
+type beginRequest struct {
+ role uint16
+ flags uint8
+ reserved [5]uint8
+}
+
+func (br *beginRequest) read(content []byte) os.Error {
+ if len(content) != 8 {
+ return os.NewError("fcgi: invalid begin request record")
+ }
+ br.role = binary.BigEndian.Uint16(content)
+ br.flags = content[2]
+ return nil
+}
+
+// for padding so we don't have to allocate all the time
+// not synchronized because we don't care what the contents are
+var pad [maxPad]byte
+
+func (h *header) init(recType uint8, reqId uint16, contentLength int) {
+ h.Version = 1
+ h.Type = recType
+ h.Id = reqId
+ h.ContentLength = uint16(contentLength)
+ h.PaddingLength = uint8(-contentLength & 7)
+}
+
+// conn sends records over rwc
+type conn struct {
+ mutex sync.Mutex
+ rwc io.ReadWriteCloser
+
+ // to avoid allocations
+ buf bytes.Buffer
+ h header
+}
+
+func newConn(rwc io.ReadWriteCloser) *conn {
+ return &conn{rwc: rwc}
+}
+
+func (c *conn) Close() os.Error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ return c.rwc.Close()
+}
+
+type record struct {
+ h header
+ buf [maxWrite + maxPad]byte
+}
+
+func (rec *record) read(r io.Reader) (err os.Error) {
+ if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
+ return err
+ }
+ if rec.h.Version != 1 {
+ return os.NewError("fcgi: invalid header version")
+ }
+ n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
+ if _, err = io.ReadFull(r, rec.buf[:n]); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r *record) content() []byte {
+ return r.buf[:r.h.ContentLength]
+}
+
+// writeRecord writes and sends a single record.
+func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) os.Error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.buf.Reset()
+ c.h.init(recType, reqId, len(b))
+ if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(b); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
+ return err
+ }
+ _, err := c.rwc.Write(c.buf.Bytes())
+ return err
+}
+
+func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) os.Error {
+ b := [8]byte{byte(role >> 8), byte(role), flags}
+ return c.writeRecord(typeBeginRequest, reqId, b[:])
+}
+
+func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) os.Error {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(appStatus))
+ b[4] = protocolStatus
+ return c.writeRecord(typeEndRequest, reqId, b)
+}
+
+func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) os.Error {
+ w := newWriter(c, recType, reqId)
+ b := make([]byte, 8)
+ for k, v := range pairs {
+ n := encodeSize(b, uint32(len(k)))
+ n += encodeSize(b[n:], uint32(len(k)))
+ if _, err := w.Write(b[:n]); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(k); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(v); err != nil {
+ return err
+ }
+ }
+ w.Close()
+ return nil
+}
+
+func readSize(s []byte) (uint32, int) {
+ if len(s) == 0 {
+ return 0, 0
+ }
+ size, n := uint32(s[0]), 1
+ if size&(1<<7) != 0 {
+ if len(s) < 4 {
+ return 0, 0
+ }
+ n = 4
+ size = binary.BigEndian.Uint32(s)
+ size &^= 1 << 31
+ }
+ return size, n
+}
+
+func readString(s []byte, size uint32) string {
+ if size > uint32(len(s)) {
+ return ""
+ }
+ return string(s[:size])
+}
+
+func encodeSize(b []byte, size uint32) int {
+ if size > 127 {
+ size |= 1 << 31
+ binary.BigEndian.PutUint32(b, size)
+ return 4
+ }
+ b[0] = byte(size)
+ return 1
+}
+
+// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
+// Closed.
+type bufWriter struct {
+ closer io.Closer
+ *bufio.Writer
+}
+
+func (w *bufWriter) Close() os.Error {
+ if err := w.Writer.Flush(); err != nil {
+ w.closer.Close()
+ return err
+ }
+ return w.closer.Close()
+}
+
+func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter {
+ s := &streamWriter{c: c, recType: recType, reqId: reqId}
+ w, _ := bufio.NewWriterSize(s, maxWrite)
+ return &bufWriter{s, w}
+}
+
+// streamWriter abstracts out the separation of a stream into discrete records.
+// It only writes maxWrite bytes at a time.
+type streamWriter struct {
+ c *conn
+ recType uint8
+ reqId uint16
+}
+
+func (w *streamWriter) Write(p []byte) (int, os.Error) {
+ nn := 0
+ for len(p) > 0 {
+ n := len(p)
+ if n > maxWrite {
+ n = maxWrite
+ }
+ if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil {
+ return nn, err
+ }
+ nn += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *streamWriter) Close() os.Error {
+ // send empty record to close the stream
+ return w.c.writeRecord(w.recType, w.reqId, nil)
+}
diff --git a/src/pkg/http/fcgi/fcgi_test.go b/src/pkg/http/fcgi/fcgi_test.go
new file mode 100644
index 000000000..16a624329
--- /dev/null
+++ b/src/pkg/http/fcgi/fcgi_test.go
@@ -0,0 +1,114 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "testing"
+)
+
+var sizeTests = []struct {
+ size uint32
+ bytes []byte
+}{
+ {0, []byte{0x00}},
+ {127, []byte{0x7F}},
+ {128, []byte{0x80, 0x00, 0x00, 0x80}},
+ {1000, []byte{0x80, 0x00, 0x03, 0xE8}},
+ {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}},
+}
+
+func TestSize(t *testing.T) {
+ b := make([]byte, 4)
+ for i, test := range sizeTests {
+ n := encodeSize(b, test.size)
+ if !bytes.Equal(b[:n], test.bytes) {
+ t.Errorf("%d expected %x, encoded %x", i, test.bytes, b)
+ }
+ size, n := readSize(test.bytes)
+ if size != test.size {
+ t.Errorf("%d expected %d, read %d", i, test.size, size)
+ }
+ if len(test.bytes) != n {
+ t.Errorf("%d did not consume all the bytes", i)
+ }
+ }
+}
+
+var streamTests = []struct {
+ desc string
+ recType uint8
+ reqId uint16
+ content []byte
+ raw []byte
+}{
+ {"single record", typeStdout, 1, nil,
+ []byte{1, typeStdout, 0, 1, 0, 0, 0, 0},
+ },
+ // this data will have to be split into two records
+ {"two records", typeStdin, 300, make([]byte, 66000),
+ bytes.Join([][]byte{
+ // header for the first record
+ []byte{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
+ make([]byte, 65536),
+ // header for the second
+ []byte{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0},
+ make([]byte, 472),
+ // header for the empty record
+ []byte{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0},
+ },
+ nil),
+ },
+}
+
+type nilCloser struct {
+ io.ReadWriter
+}
+
+func (c *nilCloser) Close() os.Error { return nil }
+
+func TestStreams(t *testing.T) {
+ var rec record
+outer:
+ for _, test := range streamTests {
+ buf := bytes.NewBuffer(test.raw)
+ var content []byte
+ for buf.Len() > 0 {
+ if err := rec.read(buf); err != nil {
+ t.Errorf("%s: error reading record: %v", test.desc, err)
+ continue outer
+ }
+ content = append(content, rec.content()...)
+ }
+ if rec.h.Type != test.recType {
+ t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType)
+ continue
+ }
+ if rec.h.Id != test.reqId {
+ t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId)
+ continue
+ }
+ if !bytes.Equal(content, test.content) {
+ t.Errorf("%s: read wrong content", test.desc)
+ continue
+ }
+ buf.Reset()
+ c := newConn(&nilCloser{buf})
+ w := newWriter(c, test.recType, test.reqId)
+ if _, err := w.Write(test.content); err != nil {
+ t.Errorf("%s: error writing record: %v", test.desc, err)
+ continue
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("%s: error closing stream: %v", test.desc, err)
+ continue
+ }
+ if !bytes.Equal(buf.Bytes(), test.raw) {
+ t.Errorf("%s: wrote wrong content", test.desc)
+ }
+ }
+}
diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go
index c5efffca9..17d5297b8 100644
--- a/src/pkg/http/fs.go
+++ b/src/pkg/http/fs.go
@@ -143,7 +143,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) {
n, _ := io.ReadFull(f, buf[:])
b := buf[:n]
if isText(b) {
- ctype = "text-plain; charset=utf-8"
+ ctype = "text/plain; charset=utf-8"
} else {
// generic binary
ctype = "application/octet-stream"
diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go
index 692b9863e..09d0981f2 100644
--- a/src/pkg/http/fs_test.go
+++ b/src/pkg/http/fs_test.go
@@ -104,7 +104,7 @@ func TestServeFileContentType(t *testing.T) {
t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
}
}
- get("text-plain; charset=utf-8")
+ get("text/plain; charset=utf-8")
override = true
get(ctype)
}
diff --git a/src/pkg/http/httptest/recorder.go b/src/pkg/http/httptest/recorder.go
index 0dd19a617..f2fedefcf 100644
--- a/src/pkg/http/httptest/recorder.go
+++ b/src/pkg/http/httptest/recorder.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The httptest package provides utilities for HTTP testing.
+// Package httptest provides utilities for HTTP testing.
package httptest
import (
diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go
index b93c5fe48..e4eea6815 100644
--- a/src/pkg/http/persist.go
+++ b/src/pkg/http/persist.go
@@ -20,8 +20,8 @@ var (
// A ServerConn reads requests and sends responses over an underlying
// connection, until the HTTP keepalive logic commands an end. ServerConn
-// does not close the underlying connection. Instead, the user calls Close
-// and regains control over the connection. ServerConn supports pipe-lining,
+// also allows hijacking the underlying connection by calling Hijack
+// to regain control over the connection. ServerConn supports pipe-lining,
// i.e. requests can be read out of sync (but in the same order) while the
// respective responses are sent.
type ServerConn struct {
@@ -45,11 +45,11 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
return &ServerConn{c: c, r: r, pipereq: make(map[*Request]uint)}
}
-// Close detaches the ServerConn and returns the underlying connection as well
-// as the read-side bufio which may have some left over data. Close may be
+// Hijack detaches the ServerConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
// called before Read has signaled the end of the keep-alive logic. The user
-// should not call Close while Read or Write is in progress.
-func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) {
+// should not call Hijack while Read or Write is in progress.
+func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) {
sc.lk.Lock()
defer sc.lk.Unlock()
c = sc.c
@@ -59,6 +59,15 @@ func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) {
return
}
+// Close calls Hijack and then also closes the underlying connection
+func (sc *ServerConn) Close() os.Error {
+ c, _ := sc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
// Read returns the next request on the wire. An ErrPersistEOF is returned if
// it is gracefully determined that there are no more requests (e.g. after the
// first request on an HTTP/1.0 connection, or after a Connection:close on a
@@ -199,9 +208,9 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error {
}
// A ClientConn sends request and receives headers over an underlying
-// connection, while respecting the HTTP keepalive logic. ClientConn is not
-// responsible for closing the underlying connection. One must call Close to
-// regain control of that connection and deal with it as desired.
+// connection, while respecting the HTTP keepalive logic. ClientConn
+// supports hijacking the connection calling Hijack to
+// regain control of the underlying net.Conn and deal with it as desired.
type ClientConn struct {
lk sync.Mutex // read-write protects the following fields
c net.Conn
@@ -239,11 +248,11 @@ func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
return cc
}
-// Close detaches the ClientConn and returns the underlying connection as well
-// as the read-side bufio which may have some left over data. Close may be
+// Hijack detaches the ClientConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
// called before the user or Read have signaled the end of the keep-alive
-// logic. The user should not call Close while Read or Write is in progress.
-func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) {
+// logic. The user should not call Hijack while Read or Write is in progress.
+func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
cc.lk.Lock()
defer cc.lk.Unlock()
c = cc.c
@@ -253,6 +262,15 @@ func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) {
return
}
+// Close calls Hijack and then also closes the underlying connection
+func (cc *ClientConn) Close() os.Error {
+ c, _ := cc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
// Write writes a request. An ErrPersistEOF error is returned if the connection
// has been closed in an HTTP keepalive sense. If req.Close equals true, the
// keepalive connection is logically closed after this request and the opposing
diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go
index 7050ef5ed..308bf44b4 100644
--- a/src/pkg/http/proxy_test.go
+++ b/src/pkg/http/proxy_test.go
@@ -16,9 +16,15 @@ var UseProxyTests = []struct {
host string
match bool
}{
- {"localhost", false}, // match completely
+ // Never proxy localhost:
+ {"localhost:80", false},
+ {"127.0.0.1", false},
+ {"127.0.0.2", false},
+ {"[::1]", false},
+ {"[::2]", true}, // not a loopback address
+
{"barbaz.net", false}, // match as .barbaz.net
- {"foobar.com:443", false}, // have a port but match
+ {"foobar.com", false}, // have a port but match
{"foofoobar.com", true}, // not match as a part of foobar.com
{"baz.com", true}, // not match as a part of barbaz.com
{"localhost.net", true}, // not match as suffix of address
@@ -29,19 +35,16 @@ var UseProxyTests = []struct {
func TestUseProxy(t *testing.T) {
oldenv := os.Getenv("NO_PROXY")
- no_proxy := "foobar.com, .barbaz.net , localhost"
- os.Setenv("NO_PROXY", no_proxy)
defer os.Setenv("NO_PROXY", oldenv)
+ no_proxy := "foobar.com, .barbaz.net"
+ os.Setenv("NO_PROXY", no_proxy)
+
tr := &Transport{}
for _, test := range UseProxyTests {
- if tr.useProxy(test.host) != test.match {
- if test.match {
- t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
- } else {
- t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy)
- }
+ if tr.useProxy(test.host+":80") != test.match {
+ t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
}
}
}
diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go
index d82894fab..b8e9a2142 100644
--- a/src/pkg/http/request.go
+++ b/src/pkg/http/request.go
@@ -4,9 +4,8 @@
// HTTP Request reading and parsing.
-// The http package implements parsing of HTTP requests, replies,
-// and URLs and provides an extensible HTTP server and a basic
-// HTTP client.
+// Package http implements parsing of HTTP requests, replies, and URLs and
+// provides an extensible HTTP server and a basic HTTP client.
package http
import (
@@ -25,12 +24,17 @@ import (
)
const (
- maxLineLength = 4096 // assumed <= bufio.defaultBufSize
- maxValueLength = 4096
- maxHeaderLines = 1024
- chunkSize = 4 << 10 // 4 KB chunks
+ maxLineLength = 4096 // assumed <= bufio.defaultBufSize
+ maxValueLength = 4096
+ maxHeaderLines = 1024
+ chunkSize = 4 << 10 // 4 KB chunks
+ defaultMaxMemory = 32 << 20 // 32 MB
)
+// ErrMissingFile is returned by FormFile when the provided file field name
+// is either not present in the request or not a file field.
+var ErrMissingFile = os.ErrorString("http: no such file")
+
// HTTP request parsing errors.
type ProtocolError struct {
os.ErrorString
@@ -65,9 +69,12 @@ var reqExcludeHeader = map[string]bool{
// A Request represents a parsed HTTP request header.
type Request struct {
- Method string // GET, POST, PUT, etc.
- RawURL string // The raw URL given in the request.
- URL *URL // Parsed URL.
+ Method string // GET, POST, PUT, etc.
+ RawURL string // The raw URL given in the request.
+ URL *URL // Parsed URL.
+
+ // The protocol version for incoming requests.
+ // Outgoing requests always use HTTP/1.1.
Proto string // "HTTP/1.0"
ProtoMajor int // 1
ProtoMinor int // 0
@@ -134,6 +141,10 @@ type Request struct {
// The parsed form. Only available after ParseForm is called.
Form map[string][]string
+ // The parsed multipart form, including file uploads.
+ // Only available after ParseMultipartForm is called.
+ MultipartForm *multipart.Form
+
// Trailer maps trailer keys to values. Like for Header, if the
// response has multiple trailer lines with the same key, they will be
// concatenated, delimited by commas.
@@ -163,9 +174,30 @@ func (r *Request) ProtoAtLeast(major, minor int) bool {
r.ProtoMajor == major && r.ProtoMinor >= minor
}
+// multipartByReader is a sentinel value.
+// Its presence in Request.MultipartForm indicates that parsing of the request
+// body has been handed off to a MultipartReader instead of ParseMultipartFrom.
+var multipartByReader = &multipart.Form{
+ Value: make(map[string][]string),
+ File: make(map[string][]*multipart.FileHeader),
+}
+
// MultipartReader returns a MIME multipart reader if this is a
// multipart/form-data POST request, else returns nil and an error.
+// Use this function instead of ParseMultipartForm to
+// process the request body as a stream.
func (r *Request) MultipartReader() (multipart.Reader, os.Error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, os.NewError("http: MultipartReader called twice")
+ }
+ if r.MultipartForm != nil {
+ return nil, os.NewError("http: multipart handled by ParseMultipartForm")
+ }
+ r.MultipartForm = multipartByReader
+ return r.multipartReader()
+}
+
+func (r *Request) multipartReader() (multipart.Reader, os.Error) {
v := r.Header.Get("Content-Type")
if v == "" {
return nil, ErrNotMultipart
@@ -199,10 +231,14 @@ const defaultUserAgent = "Go http package"
// UserAgent (defaults to defaultUserAgent)
// Referer
// Header
+// Cookie
+// ContentLength
+// TransferEncoding
// Body
//
-// If Body is present, Write forces "Transfer-Encoding: chunked" as a header
-// and then closes Body when finished sending it.
+// If Body is present but Content-Length is <= 0, Write adds
+// "Transfer-Encoding: chunked" to the header. Body is closed after
+// it is sent.
func (req *Request) Write(w io.Writer) os.Error {
return req.write(w, false)
}
@@ -420,6 +456,29 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) {
return n, cr.err
}
+// NewRequest returns a new Request given a method, URL, and optional body.
+func NewRequest(method, url string, body io.Reader) (*Request, os.Error) {
+ u, err := ParseURL(url)
+ if err != nil {
+ return nil, err
+ }
+ rc, ok := body.(io.ReadCloser)
+ if !ok && body != nil {
+ rc = ioutil.NopCloser(body)
+ }
+ req := &Request{
+ Method: method,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ Body: rc,
+ Host: u.Host,
+ }
+ return req, nil
+}
+
// ReadRequest reads and parses a request from b.
func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) {
@@ -549,7 +608,9 @@ func parseQuery(m map[string][]string, query string) (err os.Error) {
return err
}
-// ParseForm parses the request body as a form for POST requests, or the raw query for GET requests.
+// ParseForm parses the raw query.
+// For POST requests, it also parses the request body as a form.
+// ParseMultipartForm calls ParseForm automatically.
// It is idempotent.
func (r *Request) ParseForm() (err os.Error) {
if r.Form != nil {
@@ -567,18 +628,23 @@ func (r *Request) ParseForm() (err os.Error) {
ct := r.Header.Get("Content-Type")
switch strings.Split(ct, ";", 2)[0] {
case "text/plain", "application/x-www-form-urlencoded", "":
- b, e := ioutil.ReadAll(r.Body)
+ const maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
+ b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1))
if e != nil {
if err == nil {
err = e
}
break
}
+ if int64(len(b)) > maxFormSize {
+ return os.NewError("http: POST too large")
+ }
e = parseQuery(r.Form, string(b))
if err == nil {
err = e
}
- // TODO(dsymonds): Handle multipart/form-data
+ case "multipart/form-data":
+ // handled by ParseMultipartForm
default:
return &badStringError{"unknown Content-Type", ct}
}
@@ -586,11 +652,50 @@ func (r *Request) ParseForm() (err os.Error) {
return err
}
+// ParseMultipartForm parses a request body as multipart/form-data.
+// The whole request body is parsed and up to a total of maxMemory bytes of
+// its file parts are stored in memory, with the remainder stored on
+// disk in temporary files.
+// ParseMultipartForm calls ParseForm if necessary.
+// After one call to ParseMultipartForm, subsequent calls have no effect.
+func (r *Request) ParseMultipartForm(maxMemory int64) os.Error {
+ if r.Form == nil {
+ err := r.ParseForm()
+ if err != nil {
+ return err
+ }
+ }
+ if r.MultipartForm != nil {
+ return nil
+ }
+ if r.MultipartForm == multipartByReader {
+ return os.NewError("http: multipart handled by MultipartReader")
+ }
+
+ mr, err := r.multipartReader()
+ if err == ErrNotMultipart {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ f, err := mr.ReadForm(maxMemory)
+ if err != nil {
+ return err
+ }
+ for k, v := range f.Value {
+ r.Form[k] = append(r.Form[k], v...)
+ }
+ r.MultipartForm = f
+
+ return nil
+}
+
// FormValue returns the first value for the named component of the query.
-// FormValue calls ParseForm if necessary.
+// FormValue calls ParseMultipartForm and ParseForm if necessary.
func (r *Request) FormValue(key string) string {
if r.Form == nil {
- r.ParseForm()
+ r.ParseMultipartForm(defaultMaxMemory)
}
if vs := r.Form[key]; len(vs) > 0 {
return vs[0]
@@ -598,6 +703,25 @@ func (r *Request) FormValue(key string) string {
return ""
}
+// FormFile returns the first file for the provided form key.
+// FormFile calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, os.Error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, nil, os.NewError("http: multipart handled by MultipartReader")
+ }
+ if r.MultipartForm == nil {
+ err := r.ParseMultipartForm(defaultMaxMemory)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if fhs := r.MultipartForm.File[key]; len(fhs) > 0 {
+ f, err := fhs[0].Open()
+ return f, fhs[0], err
+ }
+ return nil, nil, ErrMissingFile
+}
+
func (r *Request) expectsContinue() bool {
return strings.ToLower(r.Header.Get("Expect")) == "100-continue"
}
diff --git a/src/pkg/http/request_test.go b/src/pkg/http/request_test.go
index 19083adf6..f982471d8 100644
--- a/src/pkg/http/request_test.go
+++ b/src/pkg/http/request_test.go
@@ -10,6 +10,8 @@ import (
. "http"
"http/httptest"
"io"
+ "io/ioutil"
+ "mime/multipart"
"os"
"reflect"
"regexp"
@@ -82,7 +84,7 @@ func TestPostQuery(t *testing.T) {
req.Header = Header{
"Content-Type": {"application/x-www-form-urlencoded; boo!"},
}
- req.Body = nopCloser{strings.NewReader("z=post&both=y")}
+ req.Body = ioutil.NopCloser(strings.NewReader("z=post&both=y"))
if q := req.FormValue("q"); q != "foo" {
t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
}
@@ -115,7 +117,7 @@ func TestPostContentTypeParsing(t *testing.T) {
req := &Request{
Method: "POST",
Header: Header(test.contentType),
- Body: nopCloser{bytes.NewBufferString("body")},
+ Body: ioutil.NopCloser(bytes.NewBufferString("body")),
}
err := req.ParseForm()
if !test.error && err != nil {
@@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) {
req := &Request{
Method: "POST",
Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
- Body: nopCloser{new(bytes.Buffer)},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
}
multipart, err := req.MultipartReader()
if multipart == nil {
@@ -170,9 +172,115 @@ func TestRedirect(t *testing.T) {
}
}
-// TODO: stop copy/pasting this around. move to io/ioutil?
-type nopCloser struct {
- io.Reader
+func TestMultipartRequest(t *testing.T) {
+ // Test that we can read the values and files of a
+ // multipart request with FormValue and FormFile,
+ // and that ParseMultipartForm can be called multiple times.
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm first call:", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ validateTestMultipartContents(t, req, false)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm second call:", err)
+ }
+ validateTestMultipartContents(t, req, false)
+}
+
+func TestMultipartRequestAuto(t *testing.T) {
+ // Test that FormValue and FormFile automatically invoke
+ // ParseMultipartForm and return the right values.
+ req := newTestMultipartRequest(t)
+ defer func() {
+ if req.MultipartForm != nil {
+ req.MultipartForm.RemoveAll()
+ }
+ }()
+ validateTestMultipartContents(t, req, true)
+}
+
+func newTestMultipartRequest(t *testing.T) *Request {
+ b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1))
+ req, err := NewRequest("POST", "/", b)
+ if err != nil {
+ t.Fatalf("NewRequest:", err)
+ }
+ ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary)
+ req.Header.Set("Content-type", ctype)
+ return req
+}
+
+func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) {
+ if g, e := req.FormValue("texta"), textaValue; g != e {
+ t.Errorf("texta value = %q, want %q", g, e)
+ }
+ if g, e := req.FormValue("texta"), textaValue; g != e {
+ t.Errorf("texta value = %q, want %q", g, e)
+ }
+
+ assertMem := func(n string, fd multipart.File) {
+ if _, ok := fd.(*os.File); ok {
+ t.Error(n, " is *os.File, should not be")
+ }
+ }
+ fd := testMultipartFile(t, req, "filea", "filea.txt", fileaContents)
+ assertMem("filea", fd)
+ fd = testMultipartFile(t, req, "fileb", "fileb.txt", filebContents)
+ if allMem {
+ assertMem("fileb", fd)
+ } else {
+ if _, ok := fd.(*os.File); !ok {
+ t.Errorf("fileb has unexpected underlying type %T", fd)
+ }
+ }
+}
+
+func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File {
+ f, fh, err := req.FormFile(key)
+ if err != nil {
+ t.Fatalf("FormFile(%q):", key, err)
+ }
+ if fh.Filename != expectFilename {
+ t.Errorf("filename = %q, want %q", fh.Filename, expectFilename)
+ }
+ var b bytes.Buffer
+ _, err = io.Copy(&b, f)
+ if err != nil {
+ t.Fatal("copying contents:", err)
+ }
+ if g := b.String(); g != expectContent {
+ t.Errorf("contents = %q, want %q", g, expectContent)
+ }
+ return f
}
-func (nopCloser) Close() os.Error { return nil }
+const (
+ fileaContents = "This is a test file."
+ filebContents = "Another test file."
+ textaValue = "foo"
+ textbValue = "bar"
+ boundary = `MyBoundary`
+)
+
+const message = `
+--MyBoundary
+Content-Disposition: form-data; name="filea"; filename="filea.txt"
+Content-Type: text/plain
+
+` + fileaContents + `
+--MyBoundary
+Content-Disposition: form-data; name="fileb"; filename="fileb.txt"
+Content-Type: text/plain
+
+` + filebContents + `
+--MyBoundary
+Content-Disposition: form-data; name="texta"
+
+` + textaValue + `
+--MyBoundary
+Content-Disposition: form-data; name="textb"
+
+` + textbValue + `
+--MyBoundary--
+`
diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go
index 726baa266..bb000c701 100644
--- a/src/pkg/http/requestwrite_test.go
+++ b/src/pkg/http/requestwrite_test.go
@@ -6,7 +6,10 @@ package http
import (
"bytes"
+ "io"
"io/ioutil"
+ "os"
+ "strings"
"testing"
)
@@ -133,6 +136,41 @@ var reqWriteTests = []reqWriteTest{
"Transfer-Encoding: chunked\r\n\r\n" +
"6\r\nabcdef\r\n0\r\n\r\n",
},
+
+ // HTTP/1.1 POST with Content-Length, no chunking
+ {
+ Request{
+ Method: "POST",
+ URL: &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ ContentLength: 6,
+ },
+
+ []byte("abcdef"),
+
+ "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "User-Agent: Go http package\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
// default to HTTP/1.1
{
Request{
@@ -189,3 +227,26 @@ func TestRequestWrite(t *testing.T) {
}
}
}
+
+type closeChecker struct {
+ io.Reader
+ closed bool
+}
+
+func (rc *closeChecker) Close() os.Error {
+ rc.closed = true
+ return nil
+}
+
+// TestRequestWriteClosesBody tests that Request.Write does close its request.Body.
+// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer
+// inside a NopCloser.
+func TestRequestWriteClosesBody(t *testing.T) {
+ rc := &closeChecker{Reader: strings.NewReader("my body")}
+ req, _ := NewRequest("GET", "http://foo.com/", rc)
+ buf := new(bytes.Buffer)
+ req.Write(buf)
+ if !rc.closed {
+ t.Error("body not closed after write")
+ }
+}
diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go
index 314f05b36..9e77c20c4 100644
--- a/src/pkg/http/response_test.go
+++ b/src/pkg/http/response_test.go
@@ -7,8 +7,12 @@ package http
import (
"bufio"
"bytes"
+ "compress/gzip"
+ "crypto/rand"
"fmt"
+ "os"
"io"
+ "io/ioutil"
"reflect"
"testing"
)
@@ -117,7 +121,9 @@ var respTests = []respTest{
"Transfer-Encoding: chunked\r\n" +
"\r\n" +
"0a\r\n" +
- "Body here\n" +
+ "Body here\n\r\n" +
+ "09\r\n" +
+ "continued\r\n" +
"0\r\n" +
"\r\n",
@@ -134,7 +140,7 @@ var respTests = []respTest{
TransferEncoding: []string{"chunked"},
},
- "Body here\n",
+ "Body here\ncontinued",
},
// Chunked response with Content-Length.
@@ -186,6 +192,29 @@ var respTests = []respTest{
"",
},
+ // explicit Content-Length of 0.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RequestMethod: "GET",
+ Header: Header{
+ "Content-Length": {"0"},
+ },
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
// Status line without a Reason-Phrase, but trailing space.
// (permitted by RFC 2616)
{
@@ -250,9 +279,107 @@ func TestReadResponse(t *testing.T) {
}
}
+var readResponseCloseInMiddleTests = []struct {
+ chunked, compressed bool
+}{
+ {false, false},
+ {true, false},
+ {true, true},
+}
+
+// TestReadResponseCloseInMiddle tests that closing a body after
+// reading only part of its contents advances the read to the end of
+// the request, right up until the next request.
+func TestReadResponseCloseInMiddle(t *testing.T) {
+ for _, test := range readResponseCloseInMiddleTests {
+ fatalf := func(format string, args ...interface{}) {
+ args = append([]interface{}{test.chunked, test.compressed}, args...)
+ t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...)
+ }
+ checkErr := func(err os.Error, msg string) {
+ if err == nil {
+ return
+ }
+ fatalf(msg+": %v", err)
+ }
+ var buf bytes.Buffer
+ buf.WriteString("HTTP/1.1 200 OK\r\n")
+ if test.chunked {
+ buf.WriteString("Transfer-Encoding: chunked\r\n")
+ } else {
+ buf.WriteString("Content-Length: 1000000\r\n")
+ }
+ var wr io.Writer = &buf
+ if test.chunked {
+ wr = &chunkedWriter{wr}
+ }
+ if test.compressed {
+ buf.WriteString("Content-Encoding: gzip\r\n")
+ var err os.Error
+ wr, err = gzip.NewWriter(wr)
+ checkErr(err, "gzip.NewWriter")
+ }
+ buf.WriteString("\r\n")
+
+ chunk := bytes.Repeat([]byte{'x'}, 1000)
+ for i := 0; i < 1000; i++ {
+ if test.compressed {
+ // Otherwise this compresses too well.
+ _, err := io.ReadFull(rand.Reader, chunk)
+ checkErr(err, "rand.Reader ReadFull")
+ }
+ wr.Write(chunk)
+ }
+ if test.compressed {
+ err := wr.(*gzip.Compressor).Close()
+ checkErr(err, "compressor close")
+ }
+ if test.chunked {
+ buf.WriteString("0\r\n\r\n")
+ }
+ buf.WriteString("Next Request Here")
+
+ bufr := bufio.NewReader(&buf)
+ resp, err := ReadResponse(bufr, "GET")
+ checkErr(err, "ReadResponse")
+ expectedLength := int64(-1)
+ if !test.chunked {
+ expectedLength = 1000000
+ }
+ if resp.ContentLength != expectedLength {
+ fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength)
+ }
+ if resp.Body == nil {
+ fatalf("nil body")
+ }
+ if test.compressed {
+ gzReader, err := gzip.NewReader(resp.Body)
+ checkErr(err, "gzip.NewReader")
+ resp.Body = &readFirstCloseBoth{gzReader, resp.Body}
+ }
+
+ rbuf := make([]byte, 2500)
+ n, err := io.ReadFull(resp.Body, rbuf)
+ checkErr(err, "2500 byte ReadFull")
+ if n != 2500 {
+ fatalf("ReadFull only read %d bytes", n)
+ }
+ if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) {
+ fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf))
+ }
+ resp.Body.Close()
+
+ rest, err := ioutil.ReadAll(bufr)
+ checkErr(err, "ReadAll on remainder")
+ if e, g := "Next Request Here", string(rest); e != g {
+ fatalf("for chunked=%v remainder = %q, expected %q", g, e)
+ }
+ }
+}
+
func diff(t *testing.T, prefix string, have, want interface{}) {
- hv := reflect.NewValue(have).Elem()
- wv := reflect.NewValue(want).Elem()
+ hv := reflect.ValueOf(have).Elem()
+ wv := reflect.ValueOf(want).Elem()
if hv.Type() != wv.Type() {
t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type())
}
diff --git a/src/pkg/http/reverseproxy.go b/src/pkg/http/reverseproxy.go
new file mode 100644
index 000000000..e4ce1e34c
--- /dev/null
+++ b/src/pkg/http/reverseproxy.go
@@ -0,0 +1,100 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package http
+
+import (
+ "io"
+ "log"
+ "net"
+ "strings"
+)
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+type ReverseProxy struct {
+ // Director must be a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ Director func(*Request)
+
+ // The Transport used to perform proxy requests.
+ // If nil, DefaultTransport is used.
+ Transport RoundTripper
+}
+
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
+ }
+ return a + b
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+func NewSingleHostReverseProxy(target *URL) *ReverseProxy {
+ director := func(req *Request) {
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+ req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
+ if q := req.URL.RawQuery; q != "" {
+ req.URL.RawPath = req.URL.Path + "?" + q
+ } else {
+ req.URL.RawPath = req.URL.Path
+ }
+ req.URL.RawQuery = target.RawQuery
+ }
+ return &ReverseProxy{Director: director}
+}
+
+func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) {
+ transport := p.Transport
+ if transport == nil {
+ transport = DefaultTransport
+ }
+
+ outreq := new(Request)
+ *outreq = *req // includes shallow copies of maps, but okay
+
+ p.Director(outreq)
+ outreq.Proto = "HTTP/1.1"
+ outreq.ProtoMajor = 1
+ outreq.ProtoMinor = 1
+ outreq.Close = false
+
+ if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ outreq.Header.Set("X-Forwarded-For", clientIp)
+ }
+
+ res, err := transport.RoundTrip(outreq)
+ if err != nil {
+ log.Printf("http: proxy error: %v", err)
+ rw.WriteHeader(StatusInternalServerError)
+ return
+ }
+
+ hdr := rw.Header()
+ for k, vv := range res.Header {
+ for _, v := range vv {
+ hdr.Add(k, v)
+ }
+ }
+
+ rw.WriteHeader(res.StatusCode)
+
+ if res.Body != nil {
+ io.Copy(rw, res.Body)
+ }
+}
diff --git a/src/pkg/http/reverseproxy_test.go b/src/pkg/http/reverseproxy_test.go
new file mode 100644
index 000000000..8cf7705d7
--- /dev/null
+++ b/src/pkg/http/reverseproxy_test.go
@@ -0,0 +1,50 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package http_test
+
+import (
+ . "http"
+ "http/httptest"
+ "io/ioutil"
+ "testing"
+)
+
+func TestReverseProxy(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ w.Header().Set("X-Foo", "bar")
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := ParseURL(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, _, err := Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+ t.Errorf("got X-Foo %q; expected %q", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go
index 0142dead9..c3c7b8d33 100644
--- a/src/pkg/http/serve_test.go
+++ b/src/pkg/http/serve_test.go
@@ -247,7 +247,7 @@ func TestServerTimeouts(t *testing.T) {
server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second}
go server.Serve(l)
- url := fmt.Sprintf("http://localhost:%d/", addr.Port)
+ url := fmt.Sprintf("http://%s/", addr)
// Hit the HTTP server successfully.
tr := &Transport{DisableKeepAlives: true} // they interfere with this test
@@ -265,7 +265,7 @@ func TestServerTimeouts(t *testing.T) {
// Slow client that should timeout.
t1 := time.Nanoseconds()
- conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port))
+ conn, err := net.Dial("tcp", addr.String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
@@ -588,7 +588,7 @@ func TestServerExpect(t *testing.T) {
sendf := func(format string, args ...interface{}) {
_, err := fmt.Fprintf(conn, format, args...)
if err != nil {
- t.Fatalf("Error writing %q: %v", format, err)
+ t.Fatalf("On test %#v, error writing %q: %v", test, format, err)
}
}
go func() {
@@ -616,3 +616,100 @@ func TestServerExpect(t *testing.T) {
runTest(test)
}
}
+
+func TestServerConsumesRequestBody(t *testing.T) {
+ log := make(chan string, 100)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ log <- "got_request"
+ w.WriteHeader(StatusOK)
+ log <- "wrote_header"
+ }))
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+
+ bufr := bufio.NewReader(conn)
+ gotres := make(chan bool)
+ go func() {
+ line, err := bufr.ReadString('\n')
+ if err != nil {
+ t.Fatal(err)
+ }
+ log <- line
+ gotres <- true
+ }()
+
+ size := 1 << 20
+ log <- "writing_request"
+ fmt.Fprintf(conn, "POST / HTTP/1.0\r\nContent-Length: %d\r\n\r\n", size)
+ time.Sleep(25e6) // give server chance to misbehave & speak out of turn
+ log <- "slept_after_req_headers"
+ conn.Write([]byte(strings.Repeat("a", size)))
+
+ <-gotres
+ expected := []string{
+ "writing_request", "got_request",
+ "slept_after_req_headers", "wrote_header",
+ "HTTP/1.0 200 OK\r\n"}
+ for step, e := range expected {
+ if g := <-log; e != g {
+ t.Errorf("on step %d expected %q, got %q", step, e, g)
+ }
+ }
+}
+
+func TestTimeoutHandler(t *testing.T) {
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan os.Error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ timeout := make(chan int64, 1) // write to this to force timeouts
+ ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout))
+ defer ts.Close()
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, _, err := Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusOK; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(body), "hi"; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g := <-writeErrors; g != nil {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+
+ // Times out:
+ timeout <- 1
+ res, _, err = Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ = ioutil.ReadAll(res.Body)
+ if !strings.Contains(string(body), "<title>Timeout</title>") {
+ t.Errorf("expected timeout body; got %q", string(body))
+ }
+
+ // Now make the previously-timed out handler speak again,
+ // which verifies the panic is handled:
+ sendHi <- true
+ if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+ t.Errorf("expected Write error of %v; got %v", e, g)
+ }
+}
diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go
index 3291de101..96d2cb638 100644
--- a/src/pkg/http/server.go
+++ b/src/pkg/http/server.go
@@ -22,6 +22,7 @@ import (
"path"
"strconv"
"strings"
+ "sync"
"time"
)
@@ -141,9 +142,13 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) {
type expectContinueReader struct {
resp *response
readCloser io.ReadCloser
+ closed bool
}
func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) {
+ if ecr.closed {
+ return 0, os.NewError("http: Read after Close on request Body")
+ }
if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
ecr.resp.wroteContinue = true
io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
@@ -153,6 +158,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) {
}
func (ecr *expectContinueReader) Close() os.Error {
+ ecr.closed = true
return ecr.readCloser.Close()
}
@@ -196,6 +202,16 @@ func (w *response) WriteHeader(code int) {
log.Print("http: multiple response.WriteHeader calls")
return
}
+
+ // Per RFC 2616, we should consume the request body before
+ // replying, if the handler hasn't already done so.
+ if w.req.ContentLength != 0 {
+ ecr, isExpecter := w.req.Body.(*expectContinueReader)
+ if !isExpecter || ecr.resp.wroteContinue {
+ w.req.Body.Close()
+ }
+ }
+
w.wroteHeader = true
w.status = code
if code == StatusNotModified {
@@ -407,6 +423,9 @@ func (w *response) finishRequest() {
}
w.conn.buf.Flush()
w.req.Body.Close()
+ if w.req.MultipartForm != nil {
+ w.req.MultipartForm.RemoveAll()
+ }
if w.contentLength != -1 && w.contentLength != w.written {
// Did not write enough. Avoid getting out of sync.
@@ -883,3 +902,89 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han
tlsListener := tls.NewListener(conn, config)
return Serve(tlsListener, handler)
}
+
+// TimeoutHandler returns a Handler that runs h with the given time limit.
+//
+// The new Handler calls h.ServeHTTP to handle each request, but if a
+// call runs for more than ns nanoseconds, the handler responds with
+// a 503 Service Unavailable error and the given message in its body.
+// (If msg is empty, a suitable default message will be sent.)
+// After such a timeout, writes by h to its ResponseWriter will return
+// ErrHandlerTimeout.
+func TimeoutHandler(h Handler, ns int64, msg string) Handler {
+ f := func() <-chan int64 {
+ return time.After(ns)
+ }
+ return &timeoutHandler{h, f, msg}
+}
+
+// ErrHandlerTimeout is returned on ResponseWriter Write calls
+// in handlers which have timed out.
+var ErrHandlerTimeout = os.NewError("http: Handler timeout")
+
+type timeoutHandler struct {
+ handler Handler
+ timeout func() <-chan int64 // returns channel producing a timeout
+ body string
+}
+
+func (h *timeoutHandler) errorBody() string {
+ if h.body != "" {
+ return h.body
+ }
+ return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>"
+}
+
+func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ done := make(chan bool)
+ tw := &timeoutWriter{w: w}
+ go func() {
+ h.handler.ServeHTTP(tw, r)
+ done <- true
+ }()
+ select {
+ case <-done:
+ return
+ case <-h.timeout():
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ if !tw.wroteHeader {
+ tw.w.WriteHeader(StatusServiceUnavailable)
+ tw.w.Write([]byte(h.errorBody()))
+ }
+ tw.timedOut = true
+ }
+}
+
+type timeoutWriter struct {
+ w ResponseWriter
+
+ mu sync.Mutex
+ timedOut bool
+ wroteHeader bool
+}
+
+func (tw *timeoutWriter) Header() Header {
+ return tw.w.Header()
+}
+
+func (tw *timeoutWriter) Write(p []byte) (int, os.Error) {
+ tw.mu.Lock()
+ timedOut := tw.timedOut
+ tw.mu.Unlock()
+ if timedOut {
+ return 0, ErrHandlerTimeout
+ }
+ return tw.w.Write(p)
+}
+
+func (tw *timeoutWriter) WriteHeader(code int) {
+ tw.mu.Lock()
+ if tw.timedOut || tw.wroteHeader {
+ tw.mu.Unlock()
+ return
+ }
+ tw.wroteHeader = true
+ tw.mu.Unlock()
+ tw.w.WriteHeader(code)
+}
diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go
index 41614f144..98c32bab6 100644
--- a/src/pkg/http/transfer.go
+++ b/src/pkg/http/transfer.go
@@ -7,6 +7,7 @@ package http
import (
"bufio"
"io"
+ "io/ioutil"
"os"
"strconv"
"strings"
@@ -447,17 +448,10 @@ func (b *body) Close() os.Error {
return nil
}
- trashBuf := make([]byte, 1024) // local for thread safety
- for {
- _, err := b.Read(trashBuf)
- if err == nil {
- continue
- }
- if err == os.EOF {
- break
- }
+ if _, err := io.Copy(ioutil.Discard, b); err != nil {
return err
}
+
if b.hdr == nil { // not reading trailer
return nil
}
diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go
index 7fa37af3b..73a2c2191 100644
--- a/src/pkg/http/transport.go
+++ b/src/pkg/http/transport.go
@@ -6,6 +6,7 @@ package http
import (
"bufio"
+ "bytes"
"compress/gzip"
"crypto/tls"
"encoding/base64"
@@ -217,6 +218,9 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
conn, err := net.Dial("tcp", cm.addr())
if err != nil {
+ if cm.proxyURL != nil {
+ err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
+ }
return nil, err
}
@@ -288,10 +292,28 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
// useProxy returns true if requests to addr should use a proxy,
// according to the NO_PROXY or no_proxy environment variable.
+// addr is always a canonicalAddr with a host and port.
func (t *Transport) useProxy(addr string) bool {
if len(addr) == 0 {
return true
}
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return false
+ }
+ if host == "localhost" {
+ return false
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 {
+ // 127.0.0.0/8 loopback isn't proxied.
+ return false
+ }
+ if bytes.Equal(ip, net.IPv6loopback) {
+ return false
+ }
+ }
+
no_proxy := t.getenvEitherCase("NO_PROXY")
if no_proxy == "*" {
return false
@@ -510,12 +532,13 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
re.res.Header.Del("Content-Encoding")
re.res.Header.Del("Content-Length")
re.res.ContentLength = -1
- var err os.Error
- re.res.Body, err = gzip.NewReader(re.res.Body)
+ esb := re.res.Body.(*bodyEOFSignal)
+ gzReader, err := gzip.NewReader(esb.body)
if err != nil {
pc.close()
return nil, err
}
+ esb.body = &readFirstCloseBoth{gzReader, esb.body}
}
return re.res, re.err
@@ -554,7 +577,7 @@ func responseIsKeepAlive(res *Response) bool {
func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) {
resp, err = ReadResponse(r, requestMethod)
if err == nil && resp.ContentLength != 0 {
- resp.Body = &bodyEOFSignal{resp.Body, nil}
+ resp.Body = &bodyEOFSignal{body: resp.Body}
}
return
}
@@ -563,12 +586,16 @@ func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Res
// once, right before the final Read() or Close() call returns, but after
// EOF has been seen.
type bodyEOFSignal struct {
- body io.ReadCloser
- fn func()
+ body io.ReadCloser
+ fn func()
+ isClosed bool
}
func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) {
n, err = es.body.Read(p)
+ if es.isClosed && n > 0 {
+ panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
+ }
if err == os.EOF && es.fn != nil {
es.fn()
es.fn = nil
@@ -577,6 +604,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) {
}
func (es *bodyEOFSignal) Close() (err os.Error) {
+ es.isClosed = true
err = es.body.Close()
if err == nil && es.fn != nil {
es.fn()
@@ -584,3 +612,19 @@ func (es *bodyEOFSignal) Close() (err os.Error) {
}
return
}
+
+type readFirstCloseBoth struct {
+ io.ReadCloser
+ io.Closer
+}
+
+func (r *readFirstCloseBoth) Close() os.Error {
+ if err := r.ReadCloser.Close(); err != nil {
+ r.Closer.Close()
+ return err
+ }
+ if err := r.Closer.Close(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go
index f83deedfc..a32ac4c4f 100644
--- a/src/pkg/http/transport_test.go
+++ b/src/pkg/http/transport_test.go
@@ -9,11 +9,14 @@ package http_test
import (
"bytes"
"compress/gzip"
+ "crypto/rand"
"fmt"
. "http"
"http/httptest"
+ "io"
"io/ioutil"
"os"
+ "strconv"
"testing"
"time"
)
@@ -179,35 +182,47 @@ func TestTransportIdleCacheKeys(t *testing.T) {
}
func TestTransportMaxPerHostIdleConns(t *testing.T) {
- ch := make(chan string)
+ resch := make(chan string)
+ gotReq := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- w.Write([]byte(<-ch))
+ gotReq <- true
+ msg := <-resch
+ _, err := w.Write([]byte(msg))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
}))
defer ts.Close()
maxIdleConns := 2
tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns}
c := &Client{Transport: tr}
- // Start 3 outstanding requests (will hang until we write to
- // ch)
+ // Start 3 outstanding requests and wait for the server to get them.
+ // Their responses will hang until we we write to resch, though.
donech := make(chan bool)
doReq := func() {
resp, _, err := c.Get(ts.URL)
if err != nil {
t.Error(err)
}
- ioutil.ReadAll(resp.Body)
+ _, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("ReadAll: %v", err)
+ }
donech <- true
}
go doReq()
+ <-gotReq
go doReq()
+ <-gotReq
go doReq()
+ <-gotReq
if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
}
- ch <- "res1"
+ resch <- "res1"
<-donech
keys := tr.IdleConnKeysForTesting()
if e, g := 1, len(keys); e != g {
@@ -221,13 +236,13 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
t.Errorf("after first response, expected %d idle conns; got %d", e, g)
}
- ch <- "res2"
+ resch <- "res2"
<-donech
if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g {
t.Errorf("after second response, expected %d idle conns; got %d", e, g)
}
- ch <- "res3"
+ resch <- "res3"
<-donech
if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g {
t.Errorf("after third response, still expected %d idle conns; got %d", e, g)
@@ -355,32 +370,80 @@ func TestTransportNilURL(t *testing.T) {
func TestTransportGzip(t *testing.T) {
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- if g, e := r.Header.Get("Accept-Encoding"), "gzip"; g != e {
+ const nRandBytes = 1024 * 1024
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
t.Errorf("Accept-Encoding = %q, want %q", g, e)
}
- w.Header().Set("Content-Encoding", "gzip")
+ rw.Header().Set("Content-Encoding", "gzip")
+
+ var w io.Writer = rw
+ var buf bytes.Buffer
+ if req.FormValue("chunked") == "0" {
+ w = &buf
+ defer io.Copy(rw, &buf)
+ defer func() {
+ rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+ }()
+ }
gz, _ := gzip.NewWriter(w)
- defer gz.Close()
gz.Write([]byte(testString))
-
+ if req.FormValue("body") == "large" {
+ io.Copyn(gz, rand.Reader, nRandBytes)
+ }
+ gz.Close()
}))
defer ts.Close()
- c := &Client{Transport: &Transport{}}
- res, _, err := c.Get(ts.URL)
- if err != nil {
- t.Fatal(err)
- }
- body, err := ioutil.ReadAll(res.Body)
- if err != nil {
- t.Fatal(err)
- }
- if g, e := string(body), testString; g != e {
- t.Fatalf("body = %q; want %q", g, e)
- }
- if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
- t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ for _, chunked := range []string{"1", "0"} {
+ c := &Client{Transport: &Transport{}}
+
+ // First fetch something large, but only read some of it.
+ res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked)
+ if err != nil {
+ t.Fatalf("large get: %v", err)
+ }
+ buf := make([]byte, len(testString))
+ n, err := io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatalf("partial read of large response: size=%d, %v", n, err)
+ }
+ if e, g := testString, string(buf); e != g {
+ t.Errorf("partial read got %q, expected %q", g, e)
+ }
+ res.Body.Close()
+ // Read on the body, even though it's closed
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
+ }
+
+ // Then something small.
+ res, _, err = c.Get(ts.URL + "?chunked=" + chunked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := string(body), testString; g != e {
+ t.Fatalf("body = %q; want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+
+ // Read on the body after it's been fully read:
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
+ }
+ res.Body.Close()
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after Close; got %d, %v", n, err)
+ }
}
}