diff options
Diffstat (limited to 'src/pkg/crypto/tls/tls.go')
-rw-r--r-- | src/pkg/crypto/tls/tls.go | 86 |
1 files changed, 68 insertions, 18 deletions
diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index 6c67506fc..d50e12029 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -15,6 +15,7 @@ import ( "io/ioutil" "net" "strings" + "time" ) // Server returns a new TLS server side connection @@ -27,9 +28,8 @@ func Server(conn net.Conn, config *Config) *Conn { // Client returns a new TLS client side connection // using conn as the underlying transport. -// Client interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. func Client(conn net.Conn, config *Config) *Conn { return &Conn{conn: conn, config: config, isClient: true} } @@ -77,24 +77,51 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) { return NewListener(l, config), nil } -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config) (*Conn, error) { - raddr := addr - c, err := net.Dial(network, raddr) +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) if err != nil { return nil, err } - colonPos := strings.LastIndex(raddr, ":") + colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { - colonPos = len(raddr) + colonPos = len(addr) } - hostname := raddr[:colonPos] + hostname := addr[:colonPos] if config == nil { config = defaultConfig() @@ -107,14 +134,37 @@ func Dial(network, addr string, config *Config) (*Conn, error) { c.ServerName = hostname config = &c } - conn := Client(c, config) - if err = conn.Handshake(); err != nil { - c.Close() + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() return nil, err } + return conn, nil } +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} + // LoadX509KeyPair reads and parses a public/private key pair from a pair of // files. The files must contain PEM encoded data. func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) { |