summaryrefslogtreecommitdiff
path: root/src/pkg/http/transport.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/http/transport.go')
-rw-r--r--src/pkg/http/transport.go227
1 files changed, 155 insertions, 72 deletions
diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go
index 73a2c2191..c907d85fd 100644
--- a/src/pkg/http/transport.go
+++ b/src/pkg/http/transport.go
@@ -6,12 +6,12 @@ package http
import (
"bufio"
- "bytes"
"compress/gzip"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
+ "io/ioutil"
"log"
"net"
"os"
@@ -24,7 +24,7 @@ import (
// each call to Do and uses HTTP proxies as directed by the
// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
// environment variables.
-var DefaultTransport RoundTripper = &Transport{}
+var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
// MaxIdleConnsPerHost.
@@ -36,12 +36,23 @@ const DefaultMaxIdleConnsPerHost = 2
type Transport struct {
lk sync.Mutex
idleConn map[string][]*persistConn
+ altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// TODO: tunable on global max cached connections
// TODO: tunable on timeout on cached connections
// TODO: optional pipelining
- IgnoreEnvironment bool // don't look at environment variables for proxy configuration
+ // Proxy specifies a function to return a proxy for a given
+ // Request. If the function returns a non-nil error, the
+ // request is aborted with the provided error.
+ // If Proxy is nil or returns a nil *URL, no proxy is used.
+ Proxy func(*Request) (*URL, os.Error)
+
+ // Dial specifies the dial function for creating TCP
+ // connections.
+ // If Dial is nil, net.Dial is used.
+ Dial func(net, addr string) (c net.Conn, err os.Error)
+
DisableKeepAlives bool
DisableCompression bool
@@ -51,6 +62,39 @@ type Transport struct {
MaxIdleConnsPerHost int
}
+// ProxyFromEnvironment returns the URL of the proxy to use for a
+// given request, as indicated by the environment variables
+// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy).
+// Either URL or an error is returned.
+func ProxyFromEnvironment(req *Request) (*URL, os.Error) {
+ proxy := getenvEitherCase("HTTP_PROXY")
+ if proxy == "" {
+ return nil, nil
+ }
+ if !useProxy(canonicalAddr(req.URL)) {
+ return nil, nil
+ }
+ proxyURL, err := ParseRequestURL(proxy)
+ if err != nil {
+ return nil, os.ErrorString("invalid proxy address")
+ }
+ if proxyURL.Host == "" {
+ proxyURL, err = ParseRequestURL("http://" + proxy)
+ if err != nil {
+ return nil, os.ErrorString("invalid proxy address")
+ }
+ }
+ return proxyURL, nil
+}
+
+// ProxyURL returns a proxy function (for use in a Transport)
+// that always returns the same URL.
+func ProxyURL(url *URL) func(*Request) (*URL, os.Error) {
+ return func(*Request) (*URL, os.Error) {
+ return url, nil
+ }
+}
+
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
if req.URL == nil {
@@ -59,7 +103,16 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
}
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
- return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+ t.lk.Lock()
+ var rt RoundTripper
+ if t.altProto != nil {
+ rt = t.altProto[req.URL.Scheme]
+ }
+ t.lk.Unlock()
+ if rt == nil {
+ return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+ }
+ return rt.RoundTrip(req)
}
cm, err := t.connectMethodForRequest(req)
@@ -79,6 +132,27 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
return pconn.roundTrip(req)
}
+// RegisterProtocol registers a new protocol with scheme.
+// The Transport will pass requests using the given scheme to rt.
+// It is rt's responsibility to simulate HTTP request semantics.
+//
+// RegisterProtocol can be used by other packages to provide
+// implementations of protocol schemes like "ftp" or "file".
+func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
+ if scheme == "http" || scheme == "https" {
+ panic("protocol " + scheme + " already registered")
+ }
+ t.lk.Lock()
+ defer t.lk.Unlock()
+ if t.altProto == nil {
+ t.altProto = make(map[string]RoundTripper)
+ }
+ if _, exists := t.altProto[scheme]; exists {
+ panic("protocol " + scheme + " already registered")
+ }
+ t.altProto[scheme] = rt
+}
+
// CloseIdleConnections closes any connections which were previously
// connected from previous requests but are now sitting idle in
// a "keep-alive" state. It does not interrupt any connections currently
@@ -101,21 +175,11 @@ func (t *Transport) CloseIdleConnections() {
// Private implementation past this point.
//
-func (t *Transport) getenvEitherCase(k string) string {
- if t.IgnoreEnvironment {
- return ""
- }
- if v := t.getenv(strings.ToUpper(k)); v != "" {
+func getenvEitherCase(k string) string {
+ if v := os.Getenv(strings.ToUpper(k)); v != "" {
return v
}
- return t.getenv(strings.ToLower(k))
-}
-
-func (t *Transport) getenv(k string) string {
- if t.IgnoreEnvironment {
- return ""
- }
- return os.Getenv(k)
+ return os.Getenv(strings.ToLower(k))
}
func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
@@ -123,20 +187,12 @@ func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Er
targetScheme: req.URL.Scheme,
targetAddr: canonicalAddr(req.URL),
}
-
- proxy := t.getenvEitherCase("HTTP_PROXY")
- if proxy != "" && t.useProxy(cm.targetAddr) {
- proxyURL, err := ParseRequestURL(proxy)
+ if t.Proxy != nil {
+ var err os.Error
+ cm.proxyURL, err = t.Proxy(req)
if err != nil {
- return nil, os.ErrorString("invalid proxy address")
- }
- if proxyURL.Host == "" {
- proxyURL, err = ParseRequestURL("http://" + proxy)
- if err != nil {
- return nil, os.ErrorString("invalid proxy address")
- }
+ return nil, err
}
- cm.proxyURL = proxyURL
}
return cm, nil
}
@@ -149,10 +205,7 @@ func (cm *connectMethod) proxyAuth() string {
}
proxyInfo := cm.proxyURL.RawUserinfo
if proxyInfo != "" {
- enc := base64.URLEncoding
- encoded := make([]byte, enc.EncodedLen(len(proxyInfo)))
- enc.Encode(encoded, []byte(proxyInfo))
- return "Basic " + string(encoded)
+ return "Basic " + base64.URLEncoding.EncodeToString([]byte(proxyInfo))
}
return ""
}
@@ -207,6 +260,13 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
return
}
+func (t *Transport) dial(network, addr string) (c net.Conn, err os.Error) {
+ if t.Dial != nil {
+ return t.Dial(network, addr)
+ }
+ return net.Dial(network, addr)
+}
+
// getConn dials and creates a new persistConn to the target as
// specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn
@@ -216,7 +276,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
return pc, nil
}
- conn, err := net.Dial("tcp", cm.addr())
+ conn, err := t.dial("tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
@@ -248,18 +308,22 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
}
}
case cm.targetScheme == "https":
- fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", cm.targetAddr)
- fmt.Fprintf(conn, "Host: %s\r\n", cm.targetAddr)
+ connectReq := &Request{
+ Method: "CONNECT",
+ RawURL: cm.targetAddr,
+ Host: cm.targetAddr,
+ Header: make(Header),
+ }
if pa != "" {
- fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", pa)
+ connectReq.Header.Set("Proxy-Authorization", pa)
}
- fmt.Fprintf(conn, "\r\n")
+ connectReq.Write(conn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(conn)
- resp, err := ReadResponse(br, "CONNECT")
+ resp, err := ReadResponse(br, connectReq)
if err != nil {
conn.Close()
return nil, err
@@ -285,7 +349,6 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
pconn.br = bufio.NewReader(pconn.conn)
pconn.cc = newClientConnFunc(conn, pconn.br)
- pconn.cc.readRes = readResponseWithEOFSignal
go pconn.readLoop()
return pconn, nil
}
@@ -293,7 +356,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
// useProxy returns true if requests to addr should use a proxy,
// according to the NO_PROXY or no_proxy environment variable.
// addr is always a canonicalAddr with a host and port.
-func (t *Transport) useProxy(addr string) bool {
+func useProxy(addr string) bool {
if len(addr) == 0 {
return true
}
@@ -305,16 +368,12 @@ func (t *Transport) useProxy(addr string) bool {
return false
}
if ip := net.ParseIP(host); ip != nil {
- if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 {
- // 127.0.0.0/8 loopback isn't proxied.
- return false
- }
- if bytes.Equal(ip, net.IPv6loopback) {
+ if ip.IsLoopback() {
return false
}
}
- no_proxy := t.getenvEitherCase("NO_PROXY")
+ no_proxy := getenvEitherCase("NO_PROXY")
if no_proxy == "*" {
return false
}
@@ -447,7 +506,25 @@ func (pc *persistConn) readLoop() {
}
rc := <-pc.reqch
- resp, err := pc.cc.Read(rc.req)
+ resp, err := pc.cc.readUsing(rc.req, func(buf *bufio.Reader, forReq *Request) (*Response, os.Error) {
+ resp, err := ReadResponse(buf, forReq)
+ if err != nil || resp.ContentLength == 0 {
+ return resp, err
+ }
+ if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
+ resp.Header.Del("Content-Encoding")
+ resp.Header.Del("Content-Length")
+ resp.ContentLength = -1
+ gzReader, err := gzip.NewReader(resp.Body)
+ if err != nil {
+ pc.close()
+ return nil, err
+ }
+ resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body}
+ }
+ resp.Body = &bodyEOFSignal{body: resp.Body}
+ return resp, err
+ })
if err == ErrPersistEOF {
// Succeeded, but we can't send any more
@@ -469,6 +546,17 @@ func (pc *persistConn) readLoop() {
waitForBodyRead <- true
}
} else {
+ // When there's no response body, we immediately
+ // reuse the TCP connection (putIdleConn), but
+ // we need to prevent ClientConn.Read from
+ // closing the Response.Body on the next
+ // loop, otherwise it might close the body
+ // before the client code has had a chance to
+ // read it (even though it'll just be 0, EOF).
+ pc.cc.lk.Lock()
+ pc.cc.lastbody = nil
+ pc.cc.lk.Unlock()
+
pc.t.putIdleConn(pc)
}
}
@@ -491,6 +579,11 @@ type responseAndError struct {
type requestAndChan struct {
req *Request
ch chan responseAndError
+
+ // did the Transport (as opposed to the client code) add an
+ // Accept-Encoding gzip header? only if it we set it do
+ // we transparently decode the gzip.
+ addedGzip bool
}
func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
@@ -522,25 +615,12 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
}
ch := make(chan responseAndError, 1)
- pc.reqch <- requestAndChan{req, ch}
+ pc.reqch <- requestAndChan{req, ch, requestedGzip}
re := <-ch
pc.lk.Lock()
pc.numExpectedResponses--
pc.lk.Unlock()
- if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" {
- re.res.Header.Del("Content-Encoding")
- re.res.Header.Del("Content-Length")
- re.res.ContentLength = -1
- esb := re.res.Body.(*bodyEOFSignal)
- gzReader, err := gzip.NewReader(esb.body)
- if err != nil {
- pc.close()
- return nil, err
- }
- esb.body = &readFirstCloseBoth{gzReader, esb.body}
- }
-
return re.res, re.err
}
@@ -572,16 +652,6 @@ func responseIsKeepAlive(res *Response) bool {
return false
}
-// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces
-// the response body with a bodyEOFSignal-wrapped version.
-func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) {
- resp, err = ReadResponse(r, requestMethod)
- if err == nil && resp.ContentLength != 0 {
- resp.Body = &bodyEOFSignal{body: resp.Body}
- }
- return
-}
-
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
// once, right before the final Read() or Close() call returns, but after
// EOF has been seen.
@@ -604,6 +674,9 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) {
}
func (es *bodyEOFSignal) Close() (err os.Error) {
+ if es.isClosed {
+ return nil
+ }
es.isClosed = true
err = es.body.Close()
if err == nil && es.fn != nil {
@@ -628,3 +701,13 @@ func (r *readFirstCloseBoth) Close() os.Error {
}
return nil
}
+
+// discardOnCloseReadCloser consumes all its input on Close.
+type discardOnCloseReadCloser struct {
+ io.ReadCloser
+}
+
+func (d *discardOnCloseReadCloser) Close() os.Error {
+ io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed
+ return d.ReadCloser.Close()
+}