diff options
Diffstat (limited to 'src/pkg/http/transport.go')
-rw-r--r-- | src/pkg/http/transport.go | 227 |
1 files changed, 155 insertions, 72 deletions
diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index 73a2c2191..c907d85fd 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -6,12 +6,12 @@ package http import ( "bufio" - "bytes" "compress/gzip" "crypto/tls" "encoding/base64" "fmt" "io" + "io/ioutil" "log" "net" "os" @@ -24,7 +24,7 @@ import ( // each call to Do and uses HTTP proxies as directed by the // $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) // environment variables. -var DefaultTransport RoundTripper = &Transport{} +var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. @@ -36,12 +36,23 @@ const DefaultMaxIdleConnsPerHost = 2 type Transport struct { lk sync.Mutex idleConn map[string][]*persistConn + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // TODO: tunable on global max cached connections // TODO: tunable on timeout on cached connections // TODO: optional pipelining - IgnoreEnvironment bool // don't look at environment variables for proxy configuration + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*Request) (*URL, os.Error) + + // Dial specifies the dial function for creating TCP + // connections. + // If Dial is nil, net.Dial is used. + Dial func(net, addr string) (c net.Conn, err os.Error) + DisableKeepAlives bool DisableCompression bool @@ -51,6 +62,39 @@ type Transport struct { MaxIdleConnsPerHost int } +// ProxyFromEnvironment returns the URL of the proxy to use for a +// given request, as indicated by the environment variables +// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy). +// Either URL or an error is returned. +func ProxyFromEnvironment(req *Request) (*URL, os.Error) { + proxy := getenvEitherCase("HTTP_PROXY") + if proxy == "" { + return nil, nil + } + if !useProxy(canonicalAddr(req.URL)) { + return nil, nil + } + proxyURL, err := ParseRequestURL(proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + if proxyURL.Host == "" { + proxyURL, err = ParseRequestURL("http://" + proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + } + return proxyURL, nil +} + +// ProxyURL returns a proxy function (for use in a Transport) +// that always returns the same URL. +func ProxyURL(url *URL) func(*Request) (*URL, os.Error) { + return func(*Request) (*URL, os.Error) { + return url, nil + } +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { if req.URL == nil { @@ -59,7 +103,16 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { } } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + t.lk.Lock() + var rt RoundTripper + if t.altProto != nil { + rt = t.altProto[req.URL.Scheme] + } + t.lk.Unlock() + if rt == nil { + return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + } + return rt.RoundTrip(req) } cm, err := t.connectMethodForRequest(req) @@ -79,6 +132,27 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { return pconn.roundTrip(req) } +// RegisterProtocol registers a new protocol with scheme. +// The Transport will pass requests using the given scheme to rt. +// It is rt's responsibility to simulate HTTP request semantics. +// +// RegisterProtocol can be used by other packages to provide +// implementations of protocol schemes like "ftp" or "file". +func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { + if scheme == "http" || scheme == "https" { + panic("protocol " + scheme + " already registered") + } + t.lk.Lock() + defer t.lk.Unlock() + if t.altProto == nil { + t.altProto = make(map[string]RoundTripper) + } + if _, exists := t.altProto[scheme]; exists { + panic("protocol " + scheme + " already registered") + } + t.altProto[scheme] = rt +} + // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle in // a "keep-alive" state. It does not interrupt any connections currently @@ -101,21 +175,11 @@ func (t *Transport) CloseIdleConnections() { // Private implementation past this point. // -func (t *Transport) getenvEitherCase(k string) string { - if t.IgnoreEnvironment { - return "" - } - if v := t.getenv(strings.ToUpper(k)); v != "" { +func getenvEitherCase(k string) string { + if v := os.Getenv(strings.ToUpper(k)); v != "" { return v } - return t.getenv(strings.ToLower(k)) -} - -func (t *Transport) getenv(k string) string { - if t.IgnoreEnvironment { - return "" - } - return os.Getenv(k) + return os.Getenv(strings.ToLower(k)) } func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { @@ -123,20 +187,12 @@ func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Er targetScheme: req.URL.Scheme, targetAddr: canonicalAddr(req.URL), } - - proxy := t.getenvEitherCase("HTTP_PROXY") - if proxy != "" && t.useProxy(cm.targetAddr) { - proxyURL, err := ParseRequestURL(proxy) + if t.Proxy != nil { + var err os.Error + cm.proxyURL, err = t.Proxy(req) if err != nil { - return nil, os.ErrorString("invalid proxy address") - } - if proxyURL.Host == "" { - proxyURL, err = ParseRequestURL("http://" + proxy) - if err != nil { - return nil, os.ErrorString("invalid proxy address") - } + return nil, err } - cm.proxyURL = proxyURL } return cm, nil } @@ -149,10 +205,7 @@ func (cm *connectMethod) proxyAuth() string { } proxyInfo := cm.proxyURL.RawUserinfo if proxyInfo != "" { - enc := base64.URLEncoding - encoded := make([]byte, enc.EncodedLen(len(proxyInfo))) - enc.Encode(encoded, []byte(proxyInfo)) - return "Basic " + string(encoded) + return "Basic " + base64.URLEncoding.EncodeToString([]byte(proxyInfo)) } return "" } @@ -207,6 +260,13 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } +func (t *Transport) dial(network, addr string) (c net.Conn, err os.Error) { + if t.Dial != nil { + return t.Dial(network, addr) + } + return net.Dial(network, addr) +} + // getConn dials and creates a new persistConn to the target as // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn @@ -216,7 +276,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { return pc, nil } - conn, err := net.Dial("tcp", cm.addr()) + conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) @@ -248,18 +308,22 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { } } case cm.targetScheme == "https": - fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", cm.targetAddr) - fmt.Fprintf(conn, "Host: %s\r\n", cm.targetAddr) + connectReq := &Request{ + Method: "CONNECT", + RawURL: cm.targetAddr, + Host: cm.targetAddr, + Header: make(Header), + } if pa != "" { - fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", pa) + connectReq.Header.Set("Proxy-Authorization", pa) } - fmt.Fprintf(conn, "\r\n") + connectReq.Write(conn) // Read response. // Okay to use and discard buffered reader here, because // TLS server will not speak until spoken to. br := bufio.NewReader(conn) - resp, err := ReadResponse(br, "CONNECT") + resp, err := ReadResponse(br, connectReq) if err != nil { conn.Close() return nil, err @@ -285,7 +349,6 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { pconn.br = bufio.NewReader(pconn.conn) pconn.cc = newClientConnFunc(conn, pconn.br) - pconn.cc.readRes = readResponseWithEOFSignal go pconn.readLoop() return pconn, nil } @@ -293,7 +356,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. // addr is always a canonicalAddr with a host and port. -func (t *Transport) useProxy(addr string) bool { +func useProxy(addr string) bool { if len(addr) == 0 { return true } @@ -305,16 +368,12 @@ func (t *Transport) useProxy(addr string) bool { return false } if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { - // 127.0.0.0/8 loopback isn't proxied. - return false - } - if bytes.Equal(ip, net.IPv6loopback) { + if ip.IsLoopback() { return false } } - no_proxy := t.getenvEitherCase("NO_PROXY") + no_proxy := getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false } @@ -447,7 +506,25 @@ func (pc *persistConn) readLoop() { } rc := <-pc.reqch - resp, err := pc.cc.Read(rc.req) + resp, err := pc.cc.readUsing(rc.req, func(buf *bufio.Reader, forReq *Request) (*Response, os.Error) { + resp, err := ReadResponse(buf, forReq) + if err != nil || resp.ContentLength == 0 { + return resp, err + } + if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + gzReader, err := gzip.NewReader(resp.Body) + if err != nil { + pc.close() + return nil, err + } + resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + } + resp.Body = &bodyEOFSignal{body: resp.Body} + return resp, err + }) if err == ErrPersistEOF { // Succeeded, but we can't send any more @@ -469,6 +546,17 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- true } } else { + // When there's no response body, we immediately + // reuse the TCP connection (putIdleConn), but + // we need to prevent ClientConn.Read from + // closing the Response.Body on the next + // loop, otherwise it might close the body + // before the client code has had a chance to + // read it (even though it'll just be 0, EOF). + pc.cc.lk.Lock() + pc.cc.lastbody = nil + pc.cc.lk.Unlock() + pc.t.putIdleConn(pc) } } @@ -491,6 +579,11 @@ type responseAndError struct { type requestAndChan struct { req *Request ch chan responseAndError + + // did the Transport (as opposed to the client code) add an + // Accept-Encoding gzip header? only if it we set it do + // we transparently decode the gzip. + addedGzip bool } func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { @@ -522,25 +615,12 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { } ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req, ch} + pc.reqch <- requestAndChan{req, ch, requestedGzip} re := <-ch pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() - if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" { - re.res.Header.Del("Content-Encoding") - re.res.Header.Del("Content-Length") - re.res.ContentLength = -1 - esb := re.res.Body.(*bodyEOFSignal) - gzReader, err := gzip.NewReader(esb.body) - if err != nil { - pc.close() - return nil, err - } - esb.body = &readFirstCloseBoth{gzReader, esb.body} - } - return re.res, re.err } @@ -572,16 +652,6 @@ func responseIsKeepAlive(res *Response) bool { return false } -// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces -// the response body with a bodyEOFSignal-wrapped version. -func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { - resp, err = ReadResponse(r, requestMethod) - if err == nil && resp.ContentLength != 0 { - resp.Body = &bodyEOFSignal{body: resp.Body} - } - return -} - // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before the final Read() or Close() call returns, but after // EOF has been seen. @@ -604,6 +674,9 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { } func (es *bodyEOFSignal) Close() (err os.Error) { + if es.isClosed { + return nil + } es.isClosed = true err = es.body.Close() if err == nil && es.fn != nil { @@ -628,3 +701,13 @@ func (r *readFirstCloseBoth) Close() os.Error { } return nil } + +// discardOnCloseReadCloser consumes all its input on Close. +type discardOnCloseReadCloser struct { + io.ReadCloser +} + +func (d *discardOnCloseReadCloser) Close() os.Error { + io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed + return d.ReadCloser.Close() +} |