diff options
Diffstat (limited to 'src/pkg/crypto/tls/handshake_client.go')
-rw-r--r-- | src/pkg/crypto/tls/handshake_client.go | 362 |
1 files changed, 276 insertions, 86 deletions
diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index 85e4adefc..a320fde1b 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -12,24 +12,41 @@ import ( "crypto/x509" "encoding/asn1" "errors" + "fmt" "io" + "net" "strconv" ) +type clientHandshakeState struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *ClientSessionState +} + func (c *Conn) clientHandshake() error { if c.config == nil { c.config = defaultConfig() } + if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { + return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + hello := &clientHelloMsg{ - vers: c.config.maxVersion(), - compressionMethods: []uint8{compressionNone}, - random: make([]byte, 32), - ocspStapling: true, - serverName: c.config.ServerName, - supportedCurves: []uint16{curveP256, curveP384, curveP521}, - supportedPoints: []uint8{pointFormatUncompressed}, - nextProtoNeg: len(c.config.NextProtos) > 0, + vers: c.config.maxVersion(), + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + serverName: c.config.ServerName, + supportedCurves: c.config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + nextProtoNeg: len(c.config.NextProtos) > 0, + secureRenegotiation: true, } possibleCipherSuites := c.config.cipherSuites() @@ -51,21 +68,61 @@ NextCipherSuite: } } - t := uint32(c.config.time().Unix()) - hello.random[0] = byte(t >> 24) - hello.random[1] = byte(t >> 16) - hello.random[2] = byte(t >> 8) - hello.random[3] = byte(t) - _, err := io.ReadFull(c.config.rand(), hello.random[4:]) + _, err := io.ReadFull(c.config.rand(), hello.random) if err != nil { c.sendAlert(alertInternalError) - return errors.New("short read from Rand") + return errors.New("tls: short read from Rand: " + err.Error()) } if hello.vers >= VersionTLS12 { hello.signatureAndHashes = supportedSKXSignatureAlgorithms } + var session *ClientSessionState + var cacheKey string + sessionCache := c.config.ClientSessionCache + if c.config.SessionTicketsDisabled { + sessionCache = nil + } + + if sessionCache != nil { + hello.ticketSupported = true + + // Try to resume a previously negotiated TLS session, if + // available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + candidateSession, ok := sessionCache.Get(cacheKey) + if ok { + // Check that the ciphersuite/version used for the + // previous session are still valid. + cipherSuiteOk := false + for _, id := range hello.cipherSuites { + if id == candidateSession.cipherSuite { + cipherSuiteOk = true + break + } + } + + versOk := candidateSession.vers >= c.config.minVersion() && + candidateSession.vers <= c.config.maxVersion() + if versOk && cipherSuiteOk { + session = candidateSession + } + } + } + + if session != nil { + hello.sessionTicket = session.sessionTicket + // A random session ID is used to detect when the + // server accepted the ticket and is resuming a session + // (see RFC 5077). + hello.sessionId = make([]byte, 16) + if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: short read from Rand: " + err.Error()) + } + } + c.writeRecord(recordTypeHandshake, hello.marshal()) msg, err := c.readHandshake() @@ -74,51 +131,103 @@ NextCipherSuite: } serverHello, ok := msg.(*serverHelloMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) } vers, ok := c.config.mutualVersion(serverHello.vers) if !ok || vers < VersionTLS10 { // TLS 1.0 is the minimum version supported as a client. - return c.sendAlert(alertProtocolVersion) + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) } c.vers = vers c.haveVers = true - finishedHash := newFinishedHash(c.vers) - finishedHash.Write(hello.marshal()) - finishedHash.Write(serverHello.marshal()) + suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) + if suite == nil { + c.sendAlert(alertHandshakeFailure) + return fmt.Errorf("tls: server selected an unsupported cipher suite") + } - if serverHello.compressionMethod != compressionNone { - return c.sendAlert(alertUnexpectedMessage) + hs := &clientHandshakeState{ + c: c, + serverHello: serverHello, + hello: hello, + suite: suite, + finishedHash: newFinishedHash(c.vers), + session: session, } - if !hello.nextProtoNeg && serverHello.nextProtoNeg { - c.sendAlert(alertHandshakeFailure) - return errors.New("server advertised unrequested NPN") + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + + isResume, err := hs.processServerHello() + if err != nil { + return err } - suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) - if suite == nil { - return c.sendAlert(alertHandshakeFailure) + if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(); err != nil { + return err + } + if err := hs.sendFinished(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(); err != nil { + return err + } } - msg, err = c.readHandshake() + if sessionCache != nil && hs.session != nil && session != hs.session { + sessionCache.Put(cacheKey, hs.session) + } + + c.didResume = isResume + c.handshakeComplete = true + c.cipherSuite = suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() if err != nil { return err } certMsg, ok := msg.(*certificateMsg) if !ok || len(certMsg.certificates) == 0 { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) } - finishedHash.Write(certMsg.marshal()) + hs.finishedHash.Write(certMsg.marshal()) certs := make([]*x509.Certificate, len(certMsg.certificates)) for i, asn1Data := range certMsg.certificates { cert, err := x509.ParseCertificate(asn1Data) if err != nil { c.sendAlert(alertBadCertificate) - return errors.New("failed to parse certificate from server: " + err.Error()) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) } certs[i] = cert } @@ -148,21 +257,23 @@ NextCipherSuite: case *rsa.PublicKey, *ecdsa.PublicKey: break default: - return c.sendAlert(alertUnsupportedCertificate) + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) } c.peerCertificates = certs - if serverHello.ocspStapling { + if hs.serverHello.ocspStapling { msg, err = c.readHandshake() if err != nil { return err } cs, ok := msg.(*certificateStatusMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(cs, msg) } - finishedHash.Write(cs.marshal()) + hs.finishedHash.Write(cs.marshal()) if cs.statusType == statusTypeOCSP { c.ocspResponse = cs.response @@ -174,12 +285,12 @@ NextCipherSuite: return err } - keyAgreement := suite.ka(c.vers) + keyAgreement := hs.suite.ka(c.vers) skx, ok := msg.(*serverKeyExchangeMsg) if ok { - finishedHash.Write(skx.marshal()) - err = keyAgreement.processServerKeyExchange(c.config, hello, serverHello, certs[0], skx) + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, certs[0], skx) if err != nil { c.sendAlert(alertUnexpectedMessage) return err @@ -208,7 +319,7 @@ NextCipherSuite: // ClientCertificateType, unless there is some external // arrangement to the contrary. - finishedHash.Write(certReq.marshal()) + hs.finishedHash.Write(certReq.marshal()) var rsaAvail, ecdsaAvail bool for _, certType := range certReq.certificateTypes { @@ -271,9 +382,10 @@ NextCipherSuite: shd, ok := msg.(*serverHelloDoneMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) } - finishedHash.Write(shd.marshal()) + hs.finishedHash.Write(shd.marshal()) // If the server requested a certificate then we have to send a // Certificate message, even if it's empty because we don't have a @@ -283,17 +395,17 @@ NextCipherSuite: if chainToSend != nil { certMsg.certificates = chainToSend.Certificate } - finishedHash.Write(certMsg.marshal()) + hs.finishedHash.Write(certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal()) } - preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hello, certs[0]) + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0]) if err != nil { c.sendAlert(alertInternalError) return err } if ckx != nil { - finishedHash.Write(ckx.marshal()) + hs.finishedHash.Write(ckx.marshal()) c.writeRecord(recordTypeHandshake, ckx.marshal()) } @@ -305,7 +417,7 @@ NextCipherSuite: switch key := c.config.Certificates[0].PrivateKey.(type) { case *ecdsa.PrivateKey: - digest, _, hashId := finishedHash.hashForClientCertificate(signatureECDSA) + digest, _, hashId := hs.finishedHash.hashForClientCertificate(signatureECDSA) r, s, err := ecdsa.Sign(c.config.rand(), key, digest) if err == nil { signed, err = asn1.Marshal(ecdsaSignature{r, s}) @@ -313,7 +425,7 @@ NextCipherSuite: certVerify.signatureAndHash.signature = signatureECDSA certVerify.signatureAndHash.hash = hashId case *rsa.PrivateKey: - digest, hashFunc, hashId := finishedHash.hashForClientCertificate(signatureRSA) + digest, hashFunc, hashId := hs.finishedHash.hashForClientCertificate(signatureRSA) signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest) certVerify.signatureAndHash.signature = signatureRSA certVerify.signatureAndHash.hash = hashId @@ -321,79 +433,157 @@ NextCipherSuite: err = errors.New("unknown private key type") } if err != nil { - return c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake with client certificate: " + err.Error()) } certVerify.signature = signed - finishedHash.Write(certVerify.marshal()) + hs.finishedHash.Write(certVerify.marshal()) c.writeRecord(recordTypeHandshake, certVerify.marshal()) } - masterSecret := masterFromPreMasterSecret(c.vers, preMasterSecret, hello.random, serverHello.random) - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, masterSecret, hello.random, serverHello.random, suite.macLen, suite.keyLen, suite.ivLen) + hs.masterSecret = masterFromPreMasterSecret(c.vers, preMasterSecret, hs.hello.random, hs.serverHello.random) + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c - var clientCipher interface{} - var clientHash macFunction - if suite.cipher != nil { - clientCipher = suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = suite.mac(c.vers, clientMAC) + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) } else { - clientCipher = suite.aead(clientKey, clientIV) + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) - c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return nil +} - if serverHello.nextProtoNeg { - nextProto := new(nextProtoMsg) - proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos) - nextProto.proto = proto - c.clientProtocol = proto - c.clientProtocolFallback = fallback +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} - finishedHash.Write(nextProto.marshal()) - c.writeRecord(recordTypeHandshake, nextProto.marshal()) +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") } - finished := new(finishedMsg) - finished.verifyData = finishedHash.clientSum(masterSecret) - finishedHash.Write(finished.marshal()) - c.writeRecord(recordTypeHandshake, finished.marshal()) + if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("server advertised unrequested NPN extension") + } - var serverCipher interface{} - var serverHash macFunction - if suite.cipher != nil { - serverCipher = suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = suite.mac(c.vers, serverMAC) - } else { - serverCipher = suite.aead(serverKey, serverIV) + if hs.serverResumedSession() { + // Restore masterSecret and peerCerts from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + return true, nil } - c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + return false, nil +} + +func (hs *clientHandshakeState) readFinished() error { + c := hs.c + c.readRecord(recordTypeChangeCipherSpec) - if err := c.error(); err != nil { + if err := c.in.error(); err != nil { return err } - msg, err = c.readHandshake() + msg, err := c.readHandshake() if err != nil { return err } serverFinished, ok := msg.(*finishedMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) } - verify := finishedHash.serverSum(masterSecret) + verify := hs.finishedHash.serverSum(hs.masterSecret) if len(verify) != len(serverFinished.verifyData) || subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - return c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &ClientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, } - c.handshakeComplete = true - c.cipherSuite = suite.id return nil } +func (hs *clientHandshakeState) sendFinished() error { + c := hs.c + + c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + if hs.serverHello.nextProtoNeg { + nextProto := new(nextProtoMsg) + proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) + nextProto.proto = proto + c.clientProtocol = proto + c.clientProtocolFallback = fallback + + hs.finishedHash.Write(nextProto.marshal()) + c.writeRecord(recordTypeHandshake, nextProto.marshal()) + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + c.writeRecord(recordTypeHandshake, finished.marshal()) + return nil +} + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { + if len(config.ServerName) > 0 { + return config.ServerName + } + return serverAddr.String() +} + // mutualProtocol finds the mutual Next Protocol Negotiation protocol given the // set of client and server supported protocols. The set of client supported // protocols must not be empty. It returns the resulting protocol and flag |