summaryrefslogtreecommitdiff
path: root/src/pkg/crypto/tls/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/crypto/tls/conn.go')
-rw-r--r--src/pkg/crypto/tls/conn.go169
1 files changed, 98 insertions, 71 deletions
diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go
index 2e64b88a6..8f7d2c144 100644
--- a/src/pkg/crypto/tls/conn.go
+++ b/src/pkg/crypto/tls/conn.go
@@ -12,6 +12,7 @@ import (
"crypto/subtle"
"crypto/x509"
"errors"
+ "fmt"
"io"
"net"
"sync"
@@ -27,6 +28,7 @@ type Conn struct {
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
+ handshakeErr error // error resulting from handshake
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
@@ -44,9 +46,6 @@ type Conn struct {
clientProtocol string
clientProtocolFallback bool
- // first permanent error
- connErr
-
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
@@ -56,27 +55,6 @@ type Conn struct {
tmp [16]byte
}
-type connErr struct {
- mu sync.Mutex
- value error
-}
-
-func (e *connErr) setError(err error) error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.value == nil {
- e.value = err
- }
- return err
-}
-
-func (e *connErr) error() error {
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.value
-}
-
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
@@ -104,7 +82,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
-// SetWriteDeadline sets the write deadline on the underlying conneciton.
+// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
@@ -115,6 +93,8 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex
+
+ err error // first permanent error
version uint16 // protocol version
cipher interface{} // cipher algorithm
mac macFunction
@@ -128,6 +108,18 @@ type halfConn struct {
inDigestBuf, outDigestBuf []byte
}
+func (hc *halfConn) setErrorLocked(err error) error {
+ hc.err = err
+ return err
+}
+
+func (hc *halfConn) error() error {
+ hc.Lock()
+ err := hc.err
+ hc.Unlock()
+ return err
+}
+
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
@@ -459,6 +451,8 @@ func (b *block) readFromUntil(r io.Reader, n int) error {
m, err := r.Read(b.data[len(b.data):cap(b.data)])
b.data = b.data[0 : len(b.data)+m]
if len(b.data) >= n {
+ // TODO(bradfitz,agl): slightly suspicious
+ // that we're throwing away r.Read's err here.
break
}
if err != nil {
@@ -518,14 +512,17 @@ func (c *Conn) readRecord(want recordType) error {
// else application data. (We don't support renegotiation.)
switch want {
default:
- return c.sendAlert(alertInternalError)
+ c.sendAlert(alertInternalError)
+ return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete {
- return c.sendAlert(alertInternalError)
+ c.sendAlert(alertInternalError)
+ return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete"))
}
case recordTypeApplicationData:
if !c.handshakeComplete {
- return c.sendAlert(alertInternalError)
+ c.sendAlert(alertInternalError)
+ return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
}
}
@@ -544,7 +541,7 @@ Again:
// err = io.ErrUnexpectedEOF
// }
if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.setError(err)
+ c.in.setErrorLocked(err)
}
return err
}
@@ -556,16 +553,18 @@ Again:
// an SSLv2 client.
if want == recordTypeHandshake && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
- return errors.New("tls: unsupported SSLv2 handshake received")
+ return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
}
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if c.haveVers && vers != c.vers {
- return c.sendAlert(alertProtocolVersion)
+ c.sendAlert(alertProtocolVersion)
+ return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
}
if n > maxCiphertext {
- return c.sendAlert(alertRecordOverflow)
+ c.sendAlert(alertRecordOverflow)
+ return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
}
if !c.haveVers {
// First message, be extra suspicious:
@@ -577,7 +576,8 @@ Again:
// well under a kilobyte. If the length is >= 12 kB,
// it's probably not real.
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
- return c.sendAlert(alertUnexpectedMessage)
+ c.sendAlert(alertUnexpectedMessage)
+ return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
}
}
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
@@ -585,7 +585,7 @@ Again:
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.setError(err)
+ c.in.setErrorLocked(err)
}
return err
}
@@ -594,27 +594,27 @@ Again:
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
ok, off, err := c.in.decrypt(b)
if !ok {
- return c.sendAlert(err)
+ c.in.setErrorLocked(c.sendAlert(err))
}
b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext {
- c.sendAlert(alertRecordOverflow)
+ err := c.sendAlert(alertRecordOverflow)
c.in.freeBlock(b)
- return c.error()
+ return c.in.setErrorLocked(err)
}
switch typ {
default:
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
if len(data) != 2 {
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if alert(data[1]) == alertCloseNotify {
- c.setError(io.EOF)
+ c.in.setErrorLocked(io.EOF)
break
}
switch data[0] {
@@ -623,24 +623,24 @@ Again:
c.in.freeBlock(b)
goto Again
case alertLevelError:
- c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
+ c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 {
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
err := c.in.changeCipherSpec()
if err != nil {
- c.sendAlert(err.(alert))
+ c.in.setErrorLocked(c.sendAlert(err.(alert)))
}
case recordTypeApplicationData:
if typ != want {
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
c.input = b
@@ -649,7 +649,7 @@ Again:
case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want {
- return c.sendAlert(alertNoRenegotiation)
+ return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
}
c.hand.Write(data)
}
@@ -657,7 +657,7 @@ Again:
if b != nil {
c.in.freeBlock(b)
}
- return c.error()
+ return c.in.err
}
// sendAlert sends a TLS alert message.
@@ -673,7 +673,7 @@ func (c *Conn) sendAlertLocked(err alert) error {
c.writeRecord(recordTypeAlert, c.tmp[0:2])
// closeNotify is a special case in that it isn't an error:
if err != alertCloseNotify {
- return c.setError(&net.OpError{Op: "local error", Err: err})
+ return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
return nil
}
@@ -759,7 +759,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
- return n, c.setError(&net.OpError{Op: "local error", Err: err})
+ return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
}
return
@@ -770,7 +770,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 {
- if err := c.error(); err != nil {
+ if err := c.in.err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -781,11 +781,10 @@ func (c *Conn) readHandshake() (interface{}, error) {
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
- c.sendAlert(alertInternalError)
- return nil, c.error()
+ return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
for c.hand.Len() < 4+n {
- if err := c.error(); err != nil {
+ if err := c.in.err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -799,6 +798,8 @@ func (c *Conn) readHandshake() (interface{}, error) {
m = new(clientHelloMsg)
case typeServerHello:
m = new(serverHelloMsg)
+ case typeNewSessionTicket:
+ m = new(newSessionTicketMsg)
case typeCertificate:
m = new(certificateMsg)
case typeCertificateRequest:
@@ -822,8 +823,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeFinished:
m = new(finishedMsg)
default:
- c.sendAlert(alertUnexpectedMessage)
- return nil, alertUnexpectedMessage
+ return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// The handshake message unmarshallers
@@ -832,25 +832,24 @@ func (c *Conn) readHandshake() (interface{}, error) {
data = append([]byte(nil), data...)
if !m.unmarshal(data) {
- c.sendAlert(alertUnexpectedMessage)
- return nil, alertUnexpectedMessage
+ return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
return m, nil
}
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
- if err := c.error(); err != nil {
- return 0, err
- }
-
if err := c.Handshake(); err != nil {
- return 0, c.setError(err)
+ return 0, err
}
c.out.Lock()
defer c.out.Unlock()
+ if err := c.out.err; err != nil {
+ return 0, err
+ }
+
if !c.handshakeComplete {
return 0, alertInternalError
}
@@ -869,14 +868,14 @@ func (c *Conn) Write(b []byte) (int, error) {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecord(recordTypeApplicationData, b[:1])
if err != nil {
- return n, c.setError(err)
+ return n, c.out.setErrorLocked(err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecord(recordTypeApplicationData, b)
- return n + m, c.setError(err)
+ return n + m, c.out.setErrorLocked(err)
}
// Read can be made to time out and return a net.Error with Timeout() == true
@@ -885,6 +884,11 @@ func (c *Conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
+ if len(b) == 0 {
+ // Put this after Handshake, in case people were calling
+ // Read(nil) for the side effect of the Handshake.
+ return
+ }
c.in.Lock()
defer c.in.Unlock()
@@ -893,13 +897,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// CBC IV. So this loop ignores a limited number of empty records.
const maxConsecutiveEmptyRecords = 100
for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
- for c.input == nil && c.error() == nil {
+ for c.input == nil && c.in.err == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err
}
}
- if err := c.error(); err != nil {
+ if err := c.in.err; err != nil {
return 0, err
}
@@ -909,6 +913,25 @@ func (c *Conn) Read(b []byte) (n int, err error) {
c.input = nil
}
+ // If a close-notify alert is waiting, read it so that
+ // we can return (n, EOF) instead of (n, nil), to signal
+ // to the HTTP response reading goroutine that the
+ // connection is now closed. This eliminates a race
+ // where the HTTP response reading goroutine would
+ // otherwise not observe the EOF until its next read,
+ // by which time a client goroutine might have already
+ // tried to reuse the HTTP connection for a new
+ // request.
+ // See https://codereview.appspot.com/76400046
+ // and http://golang.org/issue/3514
+ if ri := c.rawInput; ri != nil &&
+ n != 0 && err == nil &&
+ c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert {
+ if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
+ err = recErr // will be io.EOF on closeNotify
+ }
+ }
+
if n != 0 || err != nil {
return n, err
}
@@ -940,16 +963,19 @@ func (c *Conn) Close() error {
func (c *Conn) Handshake() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
- if err := c.error(); err != nil {
+ if err := c.handshakeErr; err != nil {
return err
}
if c.handshakeComplete {
return nil
}
+
if c.isClient {
- return c.clientHandshake()
+ c.handshakeErr = c.clientHandshake()
+ } else {
+ c.handshakeErr = c.serverHandshake()
}
- return c.serverHandshake()
+ return c.handshakeErr
}
// ConnectionState returns basic TLS details about the connection.
@@ -960,6 +986,7 @@ func (c *Conn) ConnectionState() ConnectionState {
var state ConnectionState
state.HandshakeComplete = c.handshakeComplete
if c.handshakeComplete {
+ state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
@@ -988,10 +1015,10 @@ func (c *Conn) VerifyHostname(host string) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if !c.isClient {
- return errors.New("VerifyHostname called on TLS server connection")
+ return errors.New("tls: VerifyHostname called on TLS server connection")
}
if !c.handshakeComplete {
- return errors.New("TLS handshake has not yet been performed")
+ return errors.New("tls: handshake has not yet been performed")
}
return c.peerCertificates[0].VerifyHostname(host)
}