diff options
author | Russ Cox <rsc@golang.org> | 2010-04-26 22:19:04 -0700 |
---|---|---|
committer | Russ Cox <rsc@golang.org> | 2010-04-26 22:19:04 -0700 |
commit | 95e18c29897387f28e3e73cd164474977ff56582 (patch) | |
tree | bfe11232af96d9016ebbe9175bb12b795a3008e5 | |
parent | 0415bdc67536a8d8c51aa26f2cd9b9cdf2d3967b (diff) | |
download | golang-95e18c29897387f28e3e73cd164474977ff56582.tar.gz |
crypto/tls: simpler implementation of record layer
Depends on CL 957045, 980043, 1004043.
Fixes issue 715.
R=agl1, agl
CC=golang-dev
http://codereview.appspot.com/943043
-rw-r--r-- | src/pkg/crypto/tls/Makefile | 6 | ||||
-rw-r--r-- | src/pkg/crypto/tls/alert.go | 88 | ||||
-rw-r--r-- | src/pkg/crypto/tls/common.go | 51 | ||||
-rw-r--r-- | src/pkg/crypto/tls/conn.go | 635 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_client.go | 188 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_messages.go | 22 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_messages_test.go | 6 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_server.go | 170 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_server_test.go | 329 | ||||
-rw-r--r-- | src/pkg/crypto/tls/record_process.go | 302 | ||||
-rw-r--r-- | src/pkg/crypto/tls/record_process_test.go | 137 | ||||
-rw-r--r-- | src/pkg/crypto/tls/record_read.go | 42 | ||||
-rw-r--r-- | src/pkg/crypto/tls/record_read_test.go | 73 | ||||
-rw-r--r-- | src/pkg/crypto/tls/record_write.go | 170 | ||||
-rw-r--r-- | src/pkg/crypto/tls/tls.go | 158 |
15 files changed, 1065 insertions, 1312 deletions
diff --git a/src/pkg/crypto/tls/Makefile b/src/pkg/crypto/tls/Makefile index 55c9d87cf..5e25bd43a 100644 --- a/src/pkg/crypto/tls/Makefile +++ b/src/pkg/crypto/tls/Makefile @@ -7,15 +7,13 @@ include ../../../Make.$(GOARCH) TARG=crypto/tls GOFILES=\ alert.go\ + ca_set.go\ common.go\ + conn.go\ handshake_client.go\ handshake_messages.go\ handshake_server.go\ prf.go\ - record_process.go\ - record_read.go\ - record_write.go\ - ca_set.go\ tls.go\ include ../../../Make.pkg diff --git a/src/pkg/crypto/tls/alert.go b/src/pkg/crypto/tls/alert.go index 2f740b39e..3b9e0e241 100644 --- a/src/pkg/crypto/tls/alert.go +++ b/src/pkg/crypto/tls/alert.go @@ -4,40 +4,70 @@ package tls -type alertLevel int -type alertType int +import "strconv" + +type alert uint8 const ( - alertLevelWarning alertLevel = 1 - alertLevelError alertLevel = 2 + // alert level + alertLevelWarning = 1 + alertLevelError = 2 ) const ( - alertCloseNotify alertType = 0 - alertUnexpectedMessage alertType = 10 - alertBadRecordMAC alertType = 20 - alertDecryptionFailed alertType = 21 - alertRecordOverflow alertType = 22 - alertDecompressionFailure alertType = 30 - alertHandshakeFailure alertType = 40 - alertBadCertificate alertType = 42 - alertUnsupportedCertificate alertType = 43 - alertCertificateRevoked alertType = 44 - alertCertificateExpired alertType = 45 - alertCertificateUnknown alertType = 46 - alertIllegalParameter alertType = 47 - alertUnknownCA alertType = 48 - alertAccessDenied alertType = 49 - alertDecodeError alertType = 50 - alertDecryptError alertType = 51 - alertProtocolVersion alertType = 70 - alertInsufficientSecurity alertType = 71 - alertInternalError alertType = 80 - alertUserCanceled alertType = 90 - alertNoRenegotiation alertType = 100 + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 ) -type alert struct { - level alertLevel - error alertType +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return s + } + return "alert(" + strconv.Itoa(int(e)) + ")" } diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index ef54a1db7..56c22cf7d 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -10,22 +10,18 @@ import ( "io" "io/ioutil" "once" - "os" "time" ) const ( - // maxTLSCiphertext is the maximum length of a plaintext payload. - maxTLSPlaintext = 16384 - // maxTLSCiphertext is the maximum length payload after compression and encryption. - maxTLSCiphertext = 16384 + 2048 - // maxHandshakeMsg is the largest single handshake message that we'll buffer. - maxHandshakeMsg = 65536 - // defaultMajor and defaultMinor are the maximum TLS version that we support. - defaultMajor = 3 - defaultMinor = 2 -) + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + minVersion = 0x0301 // minimum supported version - TLS 1.0 + maxVersion = 0x0302 // maximum supported version - TLS 1.1 +) // TLS record types. type recordType uint8 @@ -67,7 +63,7 @@ var ( type ConnectionState struct { HandshakeComplete bool CipherSuite string - Error alertType + Error alert NegotiatedProtocol string } @@ -99,6 +95,7 @@ type record struct { type handshakeMessage interface { marshal() []byte + unmarshal([]byte) bool } type encryptor interface { @@ -108,34 +105,16 @@ type encryptor interface { // mutualVersion returns the protocol version to use given the advertised // version of the peer. -func mutualVersion(theirMajor, theirMinor uint8) (major, minor uint8, ok bool) { - // We don't deal with peers < TLS 1.0 (aka version 3.1). - if theirMajor < 3 || theirMajor == 3 && theirMinor < 1 { - return 0, 0, false +func mutualVersion(vers uint16) (uint16, bool) { + if vers < minVersion { + return 0, false } - major = 3 - minor = 2 - if theirMinor < minor { - minor = theirMinor + if vers > maxVersion { + vers = maxVersion } - ok = true - return + return vers, true } -// A nop implements the NULL encryption and MAC algorithms. -type nop struct{} - -func (nop) XORKeyStream(buf []byte) {} - -func (nop) Write(buf []byte) (int, os.Error) { return len(buf), nil } - -func (nop) Sum() []byte { return nil } - -func (nop) Reset() {} - -func (nop) Size() int { return 0 } - - // The defaultConfig is used in place of a nil *Config in the TLS server and client. var varDefaultConfig *Config diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go new file mode 100644 index 000000000..d0e8464d5 --- /dev/null +++ b/src/pkg/crypto/tls/conn.go @@ -0,0 +1,635 @@ +// TLS low level connection and record layer + +package tls + +import ( + "bytes" + "crypto/subtle" + "hash" + "io" + "net" + "os" + "sync" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isClient bool + + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor + handshakeComplete bool + cipherSuite uint16 + + clientProtocol string + + // first permanent error + errMutex sync.Mutex + err os.Error + + // input/output + in, out halfConn // in.Mutex < out.Mutex + rawInput *block // raw input, right off the wire + input *block // application data waiting to be read + hand bytes.Buffer // handshake data waiting to be read + + tmp [16]byte +} + +func (c *Conn) setError(err os.Error) os.Error { + c.errMutex.Lock() + defer c.errMutex.Unlock() + + if c.err == nil { + c.err = err + } + return err +} + +func (c *Conn) error() os.Error { + c.errMutex.Lock() + defer c.errMutex.Unlock() + + return c.err +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetTimeout sets the read deadline associated with the connection. +// There is no write deadline. +func (c *Conn) SetTimeout(nsec int64) os.Error { + return c.conn.SetTimeout(nsec) +} + +// SetReadTimeout sets the time (in nanoseconds) that +// Read will wait for data before returning os.EAGAIN. +// Setting nsec == 0 (the default) disables the deadline. +func (c *Conn) SetReadTimeout(nsec int64) os.Error { + return c.conn.SetReadTimeout(nsec) +} + +// SetWriteTimeout exists to satisfy the net.Conn interface +// but is not implemented by TLS. It always returns an error. +func (c *Conn) SetWriteTimeout(nsec int64) os.Error { + return os.NewError("TLS does not support SetWriteTimeout") +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + crypt encryptor // encryption state + mac hash.Hash // MAC algorithm + seq [8]byte // 64-bit sequence number + bfree *block // list of free blocks + + nextCrypt encryptor // next encryption state + nextMac hash.Hash // next MAC algorithm +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(crypt encryptor, mac hash.Hash) { + hc.nextCrypt = crypt + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() os.Error { + if hc.nextCrypt == nil { + return alertInternalError + } + hc.crypt = hc.nextCrypt + hc.mac = hc.nextMac + hc.nextCrypt = nil + hc.nextMac = nil + return nil +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// resetSeq resets the sequence number to zero. +func (hc *halfConn) resetSeq() { + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// decrypt checks and strips the mac and decrypts the data in b. +func (hc *halfConn) decrypt(b *block) (bool, alert) { + // pull out payload + payload := b.data[recordHeaderLen:] + + // decrypt + if hc.crypt != nil { + hc.crypt.XORKeyStream(payload) + } + + // check, strip mac + if hc.mac != nil { + if len(payload) < hc.mac.Size() { + return false, alertBadRecordMAC + } + + // strip mac off payload, b.data + n := len(payload) - hc.mac.Size() + b.data[3] = byte(n >> 8) + b.data[4] = byte(n) + b.data = b.data[0 : recordHeaderLen+n] + remoteMAC := payload[n:] + + hc.mac.Reset() + hc.mac.Write(&hc.seq) + hc.incSeq() + hc.mac.Write(b.data) + + if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 { + return false, alertBadRecordMAC + } + } + + return true, 0 +} + +// encrypt encrypts and macs the data in b. +func (hc *halfConn) encrypt(b *block) (bool, alert) { + // mac + if hc.mac != nil { + hc.mac.Reset() + hc.mac.Write(&hc.seq) + hc.incSeq() + hc.mac.Write(b.data) + mac := hc.mac.Sum() + n := len(b.data) + b.resize(n + len(mac)) + copy(b.data[n:], mac) + + // update length to include mac + n = len(b.data) - recordHeaderLen + b.data[3] = byte(n >> 8) + b.data[4] = byte(n) + } + + // encrypt + if hc.crypt != nil { + hc.crypt.XORKeyStream(b.data[recordHeaderLen:]) + } + + return true, 0 +} + +// A block is a simple data buffer. +type block struct { + data []byte + off int // index for Read + link *block +} + +// resize resizes block to be n bytes, growing if necessary. +func (b *block) resize(n int) { + if n > cap(b.data) { + b.reserve(n) + } + b.data = b.data[0:n] +} + +// reserve makes sure that block contains a capacity of at least n bytes. +func (b *block) reserve(n int) { + if cap(b.data) >= n { + return + } + m := cap(b.data) + if m == 0 { + m = 1024 + } + for m < n { + m *= 2 + } + data := make([]byte, len(b.data), m) + copy(data, b.data) + b.data = data +} + +// readFromUntil reads from r into b until b contains at least n bytes +// or else returns an error. +func (b *block) readFromUntil(r io.Reader, n int) os.Error { + // quick case + if len(b.data) >= n { + return nil + } + + // read until have enough. + b.reserve(n) + for { + m, err := r.Read(b.data[len(b.data):cap(b.data)]) + b.data = b.data[0 : len(b.data)+m] + if len(b.data) >= n { + break + } + if err != nil { + return err + } + } + return nil +} + +func (b *block) Read(p []byte) (n int, err os.Error) { + n = copy(p, b.data[b.off:]) + b.off += n + return +} + +// newBlock allocates a new block, from hc's free list if possible. +func (hc *halfConn) newBlock() *block { + b := hc.bfree + if b == nil { + return new(block) + } + hc.bfree = b.link + b.link = nil + b.resize(0) + return b +} + +// freeBlock returns a block to hc's free list. +// The protocol is such that each side only has a block or two on +// its free list at a time, so there's no need to worry about +// trimming the list, etc. +func (hc *halfConn) freeBlock(b *block) { + b.link = hc.bfree + hc.bfree = b +} + +// splitBlock splits a block after the first n bytes, +// returning a block with those n bytes and a +// block with the remaindec. the latter may be nil. +func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) { + if len(b.data) <= n { + return b, nil + } + bb := hc.newBlock() + bb.resize(len(b.data) - n) + copy(bb.data, b.data[n:]) + b.data = b.data[0:n] + return b, bb +} + +// readRecord reads the next TLS record from the connection +// and updates the record layer state. +// c.in.Mutex <= L; c.input == nil. +func (c *Conn) readRecord(want recordType) os.Error { + // Caller must be in sync with connection: + // handshake data if handshake not yet completed, + // else application data. (We don't support renegotiation.) + switch want { + default: + return c.sendAlert(alertInternalError) + case recordTypeHandshake, recordTypeChangeCipherSpec: + if c.handshakeComplete { + return c.sendAlert(alertInternalError) + } + case recordTypeApplicationData: + if !c.handshakeComplete { + return c.sendAlert(alertInternalError) + } + } + +Again: + if c.rawInput == nil { + c.rawInput = c.in.newBlock() + } + b := c.rawInput + + // Read header, payload. + if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC suggests that EOF without an alertCloseNotify is + // an error, but popular web sites seem to do this, + // so we can't make it an error. + // if err == os.EOF { + // err = io.ErrUnexpectedEOF + // } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.setError(err) + } + return err + } + typ := recordType(b.data[0]) + vers := uint16(b.data[1])<<8 | uint16(b.data[2]) + n := int(b.data[3])<<8 | int(b.data[4]) + if c.haveVers && vers != c.vers { + return c.sendAlert(alertProtocolVersion) + } + if n > maxCiphertext { + return c.sendAlert(alertRecordOverflow) + } + if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.setError(err) + } + return err + } + + // Process message. + b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) + b.off = recordHeaderLen + if ok, err := c.in.decrypt(b); !ok { + return c.sendAlert(err) + } + data := b.data[b.off:] + if len(data) > maxPlaintext { + c.sendAlert(alertRecordOverflow) + c.in.freeBlock(b) + return c.error() + } + + switch typ { + default: + c.sendAlert(alertUnexpectedMessage) + + case recordTypeAlert: + if len(data) != 2 { + c.sendAlert(alertUnexpectedMessage) + break + } + if alert(data[1]) == alertCloseNotify { + c.setError(os.EOF) + break + } + switch data[0] { + case alertLevelWarning: + // drop on the floor + c.in.freeBlock(b) + goto Again + case alertLevelError: + c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])}) + default: + c.sendAlert(alertUnexpectedMessage) + } + + case recordTypeChangeCipherSpec: + if typ != want || len(data) != 1 || data[0] != 1 { + c.sendAlert(alertUnexpectedMessage) + break + } + err := c.in.changeCipherSpec() + if err != nil { + c.sendAlert(err.(alert)) + } + + case recordTypeApplicationData: + if typ != want { + c.sendAlert(alertUnexpectedMessage) + break + } + c.input = b + b = nil + + case recordTypeHandshake: + // TODO(rsc): Should at least pick off connection close. + if typ != want { + return c.sendAlert(alertNoRenegotiation) + } + c.hand.Write(data) + } + + if b != nil { + c.in.freeBlock(b) + } + return c.error() +} + +// sendAlert sends a TLS alert message. +// c.out.Mutex <= L. +func (c *Conn) sendAlertLocked(err alert) os.Error { + c.tmp[0] = alertLevelError + if err == alertNoRenegotiation { + c.tmp[0] = alertLevelWarning + } + c.tmp[1] = byte(err) + c.writeRecord(recordTypeAlert, c.tmp[0:2]) + return c.setError(&net.OpError{Op: "local error", Error: err}) +} + +// sendAlert sends a TLS alert message. +// L < c.out.Mutex. +func (c *Conn) sendAlert(err alert) os.Error { + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +// writeRecord writes a TLS record with the given type and payload +// to the connection and updates the record layer state. +// c.out.Mutex <= L. +func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) { + b := c.out.newBlock() + for len(data) > 0 { + m := len(data) + if m > maxPlaintext { + m = maxPlaintext + } + b.resize(recordHeaderLen + m) + b.data[0] = byte(typ) + vers := c.vers + if vers == 0 { + vers = maxVersion + } + b.data[1] = byte(vers >> 8) + b.data[2] = byte(vers) + b.data[3] = byte(m >> 8) + b.data[4] = byte(m) + copy(b.data[recordHeaderLen:], data) + c.out.encrypt(b) + _, err = c.conn.Write(b.data) + if err != nil { + break + } + n += m + data = data[m:] + } + c.out.freeBlock(b) + + if typ == recordTypeChangeCipherSpec { + err = c.out.changeCipherSpec() + if err != nil { + // Cannot call sendAlert directly, + // because we already hold c.out.Mutex. + c.tmp[0] = alertLevelError + c.tmp[1] = byte(err.(alert)) + c.writeRecord(recordTypeAlert, c.tmp[0:2]) + c.err = &net.OpError{Op: "local error", Error: err} + return n, c.err + } + } + return +} + +// readHandshake reads the next handshake message from +// the record layer. +// c.in.Mutex < L; c.out.Mutex < L. +func (c *Conn) readHandshake() (interface{}, os.Error) { + for c.hand.Len() < 4 { + if c.err != nil { + return nil, c.err + } + c.readRecord(recordTypeHandshake) + } + + data := c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlert(alertInternalError) + return nil, c.err + } + for c.hand.Len() < 4+n { + if c.err != nil { + return nil, c.err + } + c.readRecord(recordTypeHandshake) + } + data = c.hand.Next(4 + n) + var m handshakeMessage + switch data[0] { + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeCertificate: + m = new(certificateMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeNextProtocol: + m = new(nextProtoMsg) + case typeFinished: + m = new(finishedMsg) + default: + c.sendAlert(alertUnexpectedMessage) + return nil, alertUnexpectedMessage + } + + // The handshake message unmarshallers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = bytes.Add(nil, data) + + if !m.unmarshal(data) { + c.sendAlert(alertUnexpectedMessage) + return nil, alertUnexpectedMessage + } + return m, nil +} + +// Write writes data to the connection. +func (c *Conn) Write(b []byte) (n int, err os.Error) { + if err = c.Handshake(); err != nil { + return + } + + c.out.Lock() + defer c.out.Unlock() + + if !c.handshakeComplete { + return 0, alertInternalError + } + if c.err != nil { + return 0, c.err + } + return c.writeRecord(recordTypeApplicationData, b) +} + +// Read can be made to time out and return err == os.EAGAIN +// after a fixed time limit; see SetTimeout and SetReadTimeout. +func (c *Conn) Read(b []byte) (n int, err os.Error) { + if err = c.Handshake(); err != nil { + return + } + + c.in.Lock() + defer c.in.Unlock() + + for c.input == nil && c.err == nil { + c.readRecord(recordTypeApplicationData) + } + if c.err != nil { + return 0, c.err + } + n, err = c.input.Read(b) + if c.input.off >= len(c.input.data) { + c.in.freeBlock(c.input) + c.input = nil + } + return n, nil +} + +// Close closes the connection. +func (c *Conn) Close() os.Error { + if err := c.Handshake(); err != nil { + return err + } + return c.sendAlert(alertCloseNotify) +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// Most uses of this packge need not call Handshake +// explicitly: the first Read or Write will call it automatically. +func (c *Conn) Handshake() os.Error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if err := c.error(); err != nil { + return err + } + if c.handshakeComplete { + return nil + } + if c.isClient { + return c.clientHandshake() + } + return c.serverHandshake() +} + +// If c is a TLS server, ClientConnection returns the protocol +// requested by the client during the TLS handshake. +// Handshake must have been called already. +func (c *Conn) ClientConnection() string { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return c.clientProtocol +} diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index 8cc6b7409..dd3009802 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -12,74 +12,63 @@ import ( "crypto/subtle" "crypto/x509" "io" + "os" ) -// A serverHandshake performs the server side of the TLS 1.1 handshake protocol. -type clientHandshake struct { - writeChan chan<- interface{} - controlChan chan<- interface{} - msgChan <-chan interface{} - config *Config -} - -func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) { - h.writeChan = writeChan - h.controlChan = controlChan - h.msgChan = msgChan - h.config = config - - defer close(writeChan) - defer close(controlChan) - +func (c *Conn) clientHandshake() os.Error { finishedHash := newFinishedHash() + config := defaultConfig() + hello := &clientHelloMsg{ - major: defaultMajor, - minor: defaultMinor, + vers: maxVersion, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), } - currentTime := uint32(config.Time()) - hello.random[0] = byte(currentTime >> 24) - hello.random[1] = byte(currentTime >> 16) - hello.random[2] = byte(currentTime >> 8) - hello.random[3] = byte(currentTime) + t := uint32(config.Time()) + hello.random[0] = byte(t >> 24) + hello.random[1] = byte(t >> 16) + hello.random[2] = byte(t >> 8) + hello.random[3] = byte(t) _, err := io.ReadFull(config.Rand, hello.random[4:]) if err != nil { - h.error(alertInternalError) - return + return c.sendAlert(alertInternalError) } finishedHash.Write(hello.marshal()) - writeChan <- writerSetVersion{defaultMajor, defaultMinor} - writeChan <- hello + c.writeRecord(recordTypeHandshake, hello.marshal()) - serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg) + msg, err := c.readHandshake() + if err != nil { + return err + } + serverHello, ok := msg.(*serverHelloMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } finishedHash.Write(serverHello.marshal()) - major, minor, ok := mutualVersion(serverHello.major, serverHello.minor) + + vers, ok := mutualVersion(serverHello.vers) if !ok { - h.error(alertProtocolVersion) - return + c.sendAlert(alertProtocolVersion) } - - writeChan <- writerSetVersion{major, minor} + c.vers = vers + c.haveVers = true if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA || serverHello.compressionMethod != compressionNone { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } - certMsg, ok := h.readHandshakeMsg().(*certificateMsg) + msg, err = c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) if !ok || len(certMsg.certificates) == 0 { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } finishedHash.Write(certMsg.marshal()) @@ -87,139 +76,98 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- for i, asn1Data := range certMsg.certificates { cert, err := x509.ParseCertificate(asn1Data) if err != nil { - h.error(alertBadCertificate) - return + return c.sendAlert(alertBadCertificate) } certs[i] = cert } // TODO(agl): do better validation of certs: max path length, name restrictions etc. for i := 1; i < len(certs); i++ { - if certs[i-1].CheckSignatureFrom(certs[i]) != nil { - h.error(alertBadCertificate) - return + if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil { + return c.sendAlert(alertBadCertificate) } } - if config.RootCAs != nil { + // TODO(rsc): Find certificates for OS X 10.6. + if false && config.RootCAs != nil { root := config.RootCAs.FindParent(certs[len(certs)-1]) if root == nil { - h.error(alertBadCertificate) - return + return c.sendAlert(alertBadCertificate) } if certs[len(certs)-1].CheckSignatureFrom(root) != nil { - h.error(alertBadCertificate) - return + return c.sendAlert(alertBadCertificate) } } pub, ok := certs[0].PublicKey.(*rsa.PublicKey) if !ok { - h.error(alertUnsupportedCertificate) - return + return c.sendAlert(alertUnsupportedCertificate) } - shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg) + msg, err = c.readHandshake() + if err != nil { + return err + } + shd, ok := msg.(*serverHelloDoneMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } finishedHash.Write(shd.marshal()) ckx := new(clientKeyExchangeMsg) preMasterSecret := make([]byte, 48) - // Note that the version number in the preMasterSecret must be the - // version offered in the ClientHello. - preMasterSecret[0] = defaultMajor - preMasterSecret[1] = defaultMinor + preMasterSecret[0] = byte(hello.vers >> 8) + preMasterSecret[1] = byte(hello.vers) _, err = io.ReadFull(config.Rand, preMasterSecret[2:]) if err != nil { - h.error(alertInternalError) - return + return c.sendAlert(alertInternalError) } ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret) if err != nil { - h.error(alertInternalError) - return + return c.sendAlert(alertInternalError) } finishedHash.Write(ckx.marshal()) - writeChan <- ckx + c.writeRecord(recordTypeHandshake, ckx.marshal()) suite := cipherSuites[0] masterSecret, clientMAC, serverMAC, clientKey, serverKey := keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength) cipher, _ := rc4.NewCipher(clientKey) - writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)} + + c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC)) + c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) finished := new(finishedMsg) finished.verifyData = finishedHash.clientSum(masterSecret) finishedHash.Write(finished.marshal()) - writeChan <- finished - - // TODO(agl): this is cut-through mode which should probably be an option. - writeChan <- writerEnableApplicationData{} - - _, ok = h.readHandshakeMsg().(changeCipherSpec) - if !ok { - h.error(alertUnexpectedMessage) - return - } + c.writeRecord(recordTypeHandshake, finished.marshal()) cipher2, _ := rc4.NewCipher(serverKey) - controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)} + c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC)) + c.readRecord(recordTypeChangeCipherSpec) + if c.err != nil { + return c.err + } - serverFinished, ok := h.readHandshakeMsg().(*finishedMsg) + msg, err = c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } verify := finishedHash.serverSum(masterSecret) if len(verify) != len(serverFinished.verifyData) || subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - h.error(alertHandshakeFailure) - return + return c.sendAlert(alertHandshakeFailure) } - controlChan <- ConnectionState{HandshakeComplete: true, CipherSuite: "TLS_RSA_WITH_RC4_128_SHA"} - - // This should just block forever. - _ = h.readHandshakeMsg() - h.error(alertUnexpectedMessage) - return -} - -func (h *clientHandshake) readHandshakeMsg() interface{} { - v := <-h.msgChan - if closed(h.msgChan) { - // If the channel closed then the processor received an error - // from the peer and we don't want to echo it back to them. - h.msgChan = nil - return 0 - } - if _, ok := v.(alert); ok { - // We got an alert from the processor. We forward to the writer - // and shutdown. - h.writeChan <- v - h.msgChan = nil - return 0 - } - return v -} - -func (h *clientHandshake) error(e alertType) { - if h.msgChan != nil { - // If we didn't get an error from the processor, then we need - // to tell it about the error. - go func() { - for _ = range h.msgChan { - } - }() - h.controlChan <- ConnectionState{Error: e} - close(h.controlChan) - h.writeChan <- alert{alertLevelError, e} - } + c.handshakeComplete = true + c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA + return nil } diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go index 966314857..f0a48c863 100644 --- a/src/pkg/crypto/tls/handshake_messages.go +++ b/src/pkg/crypto/tls/handshake_messages.go @@ -6,7 +6,7 @@ package tls type clientHelloMsg struct { raw []byte - major, minor uint8 + vers uint16 random []byte sessionId []byte cipherSuites []uint16 @@ -40,8 +40,8 @@ func (m *clientHelloMsg) marshal() []byte { x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) x[3] = uint8(length) - x[4] = m.major - x[5] = m.minor + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) copy(x[6:38], m.random) x[38] = uint8(len(m.sessionId)) copy(x[39:39+len(m.sessionId)], m.sessionId) @@ -108,12 +108,11 @@ func (m *clientHelloMsg) marshal() []byte { } func (m *clientHelloMsg) unmarshal(data []byte) bool { - if len(data) < 43 { + if len(data) < 42 { return false } m.raw = data - m.major = data[4] - m.minor = data[5] + m.vers = uint16(data[4])<<8 | uint16(data[5]) m.random = data[6:38] sessionIdLen := int(data[38]) if sessionIdLen > 32 || len(data) < 39+sessionIdLen { @@ -136,7 +135,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) } data = data[2+cipherSuiteLen:] - if len(data) < 2 { + if len(data) < 1 { return false } compressionMethodsLen := int(data[0]) @@ -212,7 +211,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { type serverHelloMsg struct { raw []byte - major, minor uint8 + vers uint16 random []byte sessionId []byte cipherSuite uint16 @@ -249,8 +248,8 @@ func (m *serverHelloMsg) marshal() []byte { x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) x[3] = uint8(length) - x[4] = m.major - x[5] = m.minor + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) copy(x[6:38], m.random) x[38] = uint8(len(m.sessionId)) copy(x[39:39+len(m.sessionId)], m.sessionId) @@ -306,8 +305,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } m.raw = data - m.major = data[4] - m.minor = data[5] + m.vers = uint16(data[4])<<8 | uint16(data[5]) m.random = data[6:38] sessionIdLen := int(data[38]) if sessionIdLen > 32 || len(data) < 39+sessionIdLen { diff --git a/src/pkg/crypto/tls/handshake_messages_test.go b/src/pkg/crypto/tls/handshake_messages_test.go index 3c5902e24..2e422cc6a 100644 --- a/src/pkg/crypto/tls/handshake_messages_test.go +++ b/src/pkg/crypto/tls/handshake_messages_test.go @@ -97,8 +97,7 @@ func randomString(n int, rand *rand.Rand) string { func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &clientHelloMsg{} - m.major = uint8(rand.Intn(256)) - m.minor = uint8(rand.Intn(256)) + m.vers = uint16(rand.Intn(65536)) m.random = randomBytes(32, rand) m.sessionId = randomBytes(rand.Intn(32), rand) m.cipherSuites = make([]uint16, rand.Intn(63)+1) @@ -118,8 +117,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &serverHelloMsg{} - m.major = uint8(rand.Intn(256)) - m.minor = uint8(rand.Intn(256)) + m.vers = uint16(rand.Intn(65536)) m.random = randomBytes(32, rand) m.sessionId = randomBytes(rand.Intn(32), rand) m.cipherSuite = uint16(rand.Int31()) diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go index 50854d154..ebf956763 100644 --- a/src/pkg/crypto/tls/handshake_server.go +++ b/src/pkg/crypto/tls/handshake_server.go @@ -19,6 +19,7 @@ import ( "crypto/sha1" "crypto/subtle" "io" + "os" ) type cipherSuite struct { @@ -31,33 +32,22 @@ var cipherSuites = []cipherSuite{ cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 20, 16}, } -// A serverHandshake performs the server side of the TLS 1.1 handshake protocol. -type serverHandshake struct { - writeChan chan<- interface{} - controlChan chan<- interface{} - msgChan <-chan interface{} - config *Config -} - -func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) { - h.writeChan = writeChan - h.controlChan = controlChan - h.msgChan = msgChan - h.config = config - - defer close(writeChan) - defer close(controlChan) - - clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg) +func (c *Conn) serverHandshake() os.Error { + config := c.config + msg, err := c.readHandshake() + if err != nil { + return err + } + clientHello, ok := msg.(*clientHelloMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } - major, minor, ok := mutualVersion(clientHello.major, clientHello.minor) + vers, ok := mutualVersion(clientHello.vers) if !ok { - h.error(alertProtocolVersion) - return + return c.sendAlert(alertProtocolVersion) } + c.vers = vers + c.haveVers = true finishedHash := newFinishedHash() finishedHash.Write(clientHello.marshal()) @@ -89,23 +79,20 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- } if suite == nil || !foundCompression { - h.error(alertHandshakeFailure) - return + return c.sendAlert(alertHandshakeFailure) } - hello.major = major - hello.minor = minor + hello.vers = vers hello.cipherSuite = suite.id - currentTime := uint32(config.Time()) + t := uint32(config.Time()) hello.random = make([]byte, 32) - hello.random[0] = byte(currentTime >> 24) - hello.random[1] = byte(currentTime >> 16) - hello.random[2] = byte(currentTime >> 8) - hello.random[3] = byte(currentTime) - _, err := io.ReadFull(config.Rand, hello.random[4:]) + hello.random[0] = byte(t >> 24) + hello.random[1] = byte(t >> 16) + hello.random[2] = byte(t >> 8) + hello.random[3] = byte(t) + _, err = io.ReadFull(config.Rand, hello.random[4:]) if err != nil { - h.error(alertInternalError) - return + return c.sendAlert(alertInternalError) } hello.compressionMethod = compressionNone if clientHello.nextProtoNeg { @@ -114,41 +101,40 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- } finishedHash.Write(hello.marshal()) - writeChan <- writerSetVersion{major, minor} - writeChan <- hello + c.writeRecord(recordTypeHandshake, hello.marshal()) if len(config.Certificates) == 0 { - h.error(alertInternalError) - return + return c.sendAlert(alertInternalError) } certMsg := new(certificateMsg) certMsg.certificates = config.Certificates[0].Certificate finishedHash.Write(certMsg.marshal()) - writeChan <- certMsg + c.writeRecord(recordTypeHandshake, certMsg.marshal()) helloDone := new(serverHelloDoneMsg) finishedHash.Write(helloDone.marshal()) - writeChan <- helloDone + c.writeRecord(recordTypeHandshake, helloDone.marshal()) - ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg) + msg, err = c.readHandshake() + if err != nil { + return err + } + ckx, ok := msg.(*clientKeyExchangeMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } finishedHash.Write(ckx.marshal()) preMasterSecret := make([]byte, 48) _, err = io.ReadFull(config.Rand, preMasterSecret[2:]) if err != nil { - h.error(alertInternalError) - return + return c.sendAlert(alertInternalError) } err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) if err != nil { - h.error(alertHandshakeFailure) - return + return c.sendAlert(alertHandshakeFailure) } // We don't check the version number in the premaster secret. For one, // by checking it, we would leak information about the validity of the @@ -160,91 +146,53 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- masterSecret, clientMAC, serverMAC, clientKey, serverKey := keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength) - _, ok = h.readHandshakeMsg().(changeCipherSpec) - if !ok { - h.error(alertUnexpectedMessage) - return - } - cipher, _ := rc4.NewCipher(clientKey) - controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)} + c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC)) + c.readRecord(recordTypeChangeCipherSpec) + if err := c.error(); err != nil { + return err + } - clientProtocol := "" if hello.nextProtoNeg { - nextProto, ok := h.readHandshakeMsg().(*nextProtoMsg) + msg, err = c.readHandshake() + if err != nil { + return err + } + nextProto, ok := msg.(*nextProtoMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } finishedHash.Write(nextProto.marshal()) - clientProtocol = nextProto.proto + c.clientProtocol = nextProto.proto } - clientFinished, ok := h.readHandshakeMsg().(*finishedMsg) + msg, err = c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) if !ok { - h.error(alertUnexpectedMessage) - return + return c.sendAlert(alertUnexpectedMessage) } verify := finishedHash.clientSum(masterSecret) if len(verify) != len(clientFinished.verifyData) || subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - h.error(alertHandshakeFailure) - return + return c.sendAlert(alertHandshakeFailure) } - controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, clientProtocol} - finishedHash.Write(clientFinished.marshal()) cipher2, _ := rc4.NewCipher(serverKey) - writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)} + c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC)) + c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) finished := new(finishedMsg) finished.verifyData = finishedHash.serverSum(masterSecret) - writeChan <- finished - - writeChan <- writerEnableApplicationData{} - - for { - _, ok := h.readHandshakeMsg().(*clientHelloMsg) - if !ok { - h.error(alertUnexpectedMessage) - return - } - // We reject all renegotication requests. - writeChan <- alert{alertLevelWarning, alertNoRenegotiation} - } -} + c.writeRecord(recordTypeHandshake, finished.marshal()) -func (h *serverHandshake) readHandshakeMsg() interface{} { - v := <-h.msgChan - if closed(h.msgChan) { - // If the channel closed then the processor received an error - // from the peer and we don't want to echo it back to them. - h.msgChan = nil - return 0 - } - if _, ok := v.(alert); ok { - // We got an alert from the processor. We forward to the writer - // and shutdown. - h.writeChan <- v - h.msgChan = nil - return 0 - } - return v -} + c.handshakeComplete = true + c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA -func (h *serverHandshake) error(e alertType) { - if h.msgChan != nil { - // If we didn't get an error from the processor, then we need - // to tell it about the error. - go func() { - for _ = range h.msgChan { - } - }() - h.controlChan <- ConnectionState{false, "", e, ""} - close(h.controlChan) - h.writeChan <- alert{alertLevelError, e} - } + return nil } diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go index a580b14e3..d31dc497e 100644 --- a/src/pkg/crypto/tls/handshake_server_test.go +++ b/src/pkg/crypto/tls/handshake_server_test.go @@ -5,12 +5,16 @@ package tls import ( - "bytes" + // "bytes" "big" "crypto/rsa" + "encoding/hex" + "flag" + "io" + "net" "os" "testing" - "testing/script" + // "testing/script" ) type zeroSource struct{} @@ -34,29 +38,23 @@ func init() { testConfig.Certificates[0].PrivateKey = testPrivateKey } -func setupServerHandshake() (writeChan chan interface{}, controlChan chan interface{}, msgChan chan interface{}) { - sh := new(serverHandshake) - writeChan = make(chan interface{}) - controlChan = make(chan interface{}) - msgChan = make(chan interface{}) - - go sh.loop(writeChan, controlChan, msgChan, testConfig) - return -} - -func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert alertType) { - writeChan, controlChan, msgChan := setupServerHandshake() - defer close(msgChan) - - send := script.NewEvent("send", nil, script.Send{msgChan, clientHello}) - recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}}) - close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan}) - recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert, ""}}) - close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan}) - - err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2}) - if err != nil { - t.Errorf("Got error: %s", err) +func testClientHelloFailure(t *testing.T, m handshakeMessage, expected os.Error) { + // Create in-memory network connection, + // send message to server. Should return + // expected error. + c, s := net.Pipe() + go func() { + cli := Client(c, testConfig) + if ch, ok := m.(*clientHelloMsg); ok { + cli.vers = ch.vers + } + cli.writeRecord(recordTypeHandshake, m.marshal()) + c.Close() + }() + err := Server(s, testConfig).Handshake() + s.Close() + if e, ok := err.(*net.OpError); !ok || e.Error != expected { + t.Errorf("Got error: %s; expected: %s", err, expected) } } @@ -64,134 +62,100 @@ func TestSimpleError(t *testing.T) { testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage) } -var badProtocolVersions = []uint8{0, 0, 0, 5, 1, 0, 1, 5, 2, 0, 2, 5, 3, 0} +var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205, 0x0300} func TestRejectBadProtocolVersion(t *testing.T) { - clientHello := new(clientHelloMsg) - - for i := 0; i < len(badProtocolVersions); i += 2 { - clientHello.major = badProtocolVersions[i] - clientHello.minor = badProtocolVersions[i+1] - - testClientHelloFailure(t, clientHello, alertProtocolVersion) + for _, v := range badProtocolVersions { + testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion) } } func TestNoSuiteOverlap(t *testing.T) { - clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""} + clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""} testClientHelloFailure(t, clientHello, alertHandshakeFailure) } func TestNoCompressionOverlap(t *testing.T) { - clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""} + clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""} testClientHelloFailure(t, clientHello, alertHandshakeFailure) } -func matchServerHello(v interface{}) bool { - serverHello, ok := v.(*serverHelloMsg) - if !ok { - return false - } - return serverHello.major == 3 && - serverHello.minor == 2 && - serverHello.cipherSuite == TLS_RSA_WITH_RC4_128_SHA && - serverHello.compressionMethod == compressionNone -} - func TestAlertForwarding(t *testing.T) { - writeChan, controlChan, msgChan := setupServerHandshake() - defer close(msgChan) - - a := alert{alertLevelError, alertNoRenegotiation} - sendAlert := script.NewEvent("send alert", nil, script.Send{msgChan, a}) - recvAlert := script.NewEvent("recv alert", []*script.Event{sendAlert}, script.Recv{writeChan, a}) - closeWriter := script.NewEvent("close writer", []*script.Event{recvAlert}, script.Closed{writeChan}) - closeControl := script.NewEvent("close control", []*script.Event{recvAlert}, script.Closed{controlChan}) - - err := script.Perform(0, []*script.Event{sendAlert, recvAlert, closeWriter, closeControl}) - if err != nil { - t.Errorf("Got error: %s", err) + c, s := net.Pipe() + go func() { + Client(c, testConfig).sendAlert(alertUnknownCA) + c.Close() + }() + + err := Server(s, testConfig).Handshake() + s.Close() + if e, ok := err.(*net.OpError); !ok || e.Error != os.Error(alertUnknownCA) { + t.Errorf("Got error: %s; expected: %s", err, alertUnknownCA) } } func TestClose(t *testing.T) { - writeChan, controlChan, msgChan := setupServerHandshake() - - close := script.NewEvent("close", nil, script.Close{msgChan}) - closed1 := script.NewEvent("closed1", []*script.Event{close}, script.Closed{writeChan}) - closed2 := script.NewEvent("closed2", []*script.Event{close}, script.Closed{controlChan}) - - err := script.Perform(0, []*script.Event{close, closed1, closed2}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} + c, s := net.Pipe() + go c.Close() -func matchCertificate(v interface{}) bool { - cert, ok := v.(*certificateMsg) - if !ok { - return false + err := Server(s, testConfig).Handshake() + s.Close() + if err != os.EOF { + t.Errorf("Got error: %s; expected: %s", err, os.EOF) } - return len(cert.certificates) == 1 && - bytes.Compare(cert.certificates[0], testCertificate) == 0 } -func matchSetCipher(v interface{}) bool { - _, ok := v.(writerChangeCipherSpec) - return ok -} -func matchDone(v interface{}) bool { - _, ok := v.(*serverHelloDoneMsg) - return ok -} +func TestHandshakeServer(t *testing.T) { + c, s := net.Pipe() + srv := Server(s, testConfig) + go func() { + srv.Write([]byte("hello, world\n")) + srv.Close() + }() + + defer c.Close() + for i, b := range serverScript { + if i%2 == 0 { + c.Write(b) + continue + } + bb := make([]byte, len(b)) + _, err := io.ReadFull(c, bb) + if err != nil { + t.Fatalf("#%d: %s", i, err) + } + } -func matchFinished(v interface{}) bool { - finished, ok := v.(*finishedMsg) - if !ok { - return false + if !srv.haveVers || srv.vers != 0x0302 { + t.Errorf("server version incorrect: %v %v", srv.haveVers, srv.vers) } - return bytes.Compare(finished.verifyData, fromHex("29122ae11453e631487b02ed")) == 0 -} -func matchNewCipherSpec(v interface{}) bool { - _, ok := v.(*newCipherSpec) - return ok + // TODO: check protocol } -func TestFullHandshake(t *testing.T) { - writeChan, controlChan, msgChan := setupServerHandshake() - defer close(msgChan) - - // The values for this test were obtained from running `gnutls-cli --insecure --debug 9` - clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}, false, ""} +var serve = flag.Bool("serve", false, "run a TLS server on :10443") - sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello}) - setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}}) - recvHello := script.NewEvent("recv hello", []*script.Event{setVersion}, script.RecvMatch{writeChan, matchServerHello}) - recvCert := script.NewEvent("recv cert", []*script.Event{recvHello}, script.RecvMatch{writeChan, matchCertificate}) - recvDone := script.NewEvent("recv done", []*script.Event{recvCert}, script.RecvMatch{writeChan, matchDone}) - - ckx := &clientKeyExchangeMsg{nil, fromHex("872e1fee5f37dd86f3215938ac8de20b302b90074e9fb93097e6b7d1286d0f45abf2daf179deb618bb3c70ed0afee6ee24476ee4649e5a23358143c0f1d9c251")} - sendCKX := script.NewEvent("send ckx", []*script.Event{recvDone}, script.Send{msgChan, ckx}) - - sendCCS := script.NewEvent("send ccs", []*script.Event{sendCKX}, script.Send{msgChan, changeCipherSpec{}}) - recvNCS := script.NewEvent("recv done", []*script.Event{sendCCS}, script.RecvMatch{controlChan, matchNewCipherSpec}) - - finished := &finishedMsg{nil, fromHex("c8faca5d242f4423325c5b1a")} - sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished}) - recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished}) - setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher}) - recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, ""}}) +func TestRunServer(t *testing.T) { + if !*serve { + return + } - err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished}) + l, err := Listen("tcp", ":10443", testConfig) if err != nil { - t.Errorf("Got error: %s", err) + t.Fatal(err) } -} -var testCertificate = fromHex("3082025930820203a003020102020900c2ec326b95228959300d06092a864886f70d01010505003054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374301e170d3039313032303232323434355a170d3130313032303232323434355a3054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374305c300d06092a864886f70d0101010500034b003048024100b2990f49c47dfa8cd400ae6a4d1b8a3b6a13642b23f28b003bfb97790ade9a4cc82b8b2a81747ddec08b6296e53a08c331687ef25c4bf4936ba1c0e6041e9d150203010001a381b73081b4301d0603551d0e0416041478a06086837c9293a8c9b70c0bdabdb9d77eeedf3081840603551d23047d307b801478a06086837c9293a8c9b70c0bdabdb9d77eeedfa158a4563054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374820900c2ec326b95228959300c0603551d13040530030101ff300d06092a864886f70d0101050500034100ac23761ae1349d85a439caad4d0b932b09ea96de1917c3e0507c446f4838cb3076fb4d431db8c1987e96f1d7a8a2054dea3a64ec99a3f0eda4d47a163bf1f6ac") + for { + c, err := l.Accept() + if err != nil { + break + } + c.Write([]byte("hello, world\n")) + c.Close() + } +} func bigFromString(s string) *big.Int { ret := new(big.Int) @@ -199,12 +163,131 @@ func bigFromString(s string) *big.Int { return ret } +func fromHex(s string) []byte { + b, _ := hex.DecodeString(s) + return b +} + +var testCertificate = fromHex("308202b030820219a00302010202090085b0bba48a7fb8ca300d06092a864886f70d01010505003045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3130303432343039303933385a170d3131303432343039303933385a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819f300d06092a864886f70d010101050003818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a381a73081a4301d0603551d0e04160414b1ade2855acfcb28db69ce2369ded3268e18883930750603551d23046e306c8014b1ade2855acfcb28db69ce2369ded3268e188839a149a4473045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746482090085b0bba48a7fb8ca300c0603551d13040530030101ff300d06092a864886f70d010105050003818100086c4524c76bb159ab0c52ccf2b014d7879d7a6475b55a9566e4c52b8eae12661feb4f38b36e60d392fdf74108b52513b1187a24fb301dbaed98b917ece7d73159db95d31d78ea50565cd5825a2d5a5f33c4b6d8c97590968c0f5298b5cd981f89205ff2a01ca31b9694dda9fd57e970e8266d71999b266e3850296c90a7bdd9") + var testPrivateKey = &rsa.PrivateKey{ PublicKey: rsa.PublicKey{ - N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), + N: bigFromString("131650079503776001033793877885499001334664249354723305978524647182322416328664556247316495448366990052837680518067798333412266673813370895702118944398081598789828837447552603077848001020611640547221687072142537202428102790818451901395596882588063427854225330436740647715202971973145151161964464812406232198521"), E: 65537, }, - D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"), + P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), + Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), +} + +// Script of interaction with gnutls implementation. +// The values for this test are obtained by building a test binary (gotest) +// and then running 6.out -serve to start a server and then +// gnutls-cli --insecure --debug 100 -p 10443 localhost +// to dump a session. +var serverScript = [][]byte{ + // Alternate write and read. + []byte{ + 0x16, 0x03, 0x02, 0x00, 0x71, 0x01, 0x00, 0x00, 0x6d, 0x03, 0x02, 0x4b, 0xd4, 0xee, 0x6e, 0xab, + 0x0b, 0xc3, 0x01, 0xd6, 0x8d, 0xe0, 0x72, 0x7e, 0x6c, 0x04, 0xbe, 0x9a, 0x3c, 0xa3, 0xd8, 0x95, + 0x28, 0x00, 0xb2, 0xe8, 0x1f, 0xdd, 0xb0, 0xec, 0xca, 0x46, 0x1f, 0x00, 0x00, 0x28, 0x00, 0x33, + 0x00, 0x39, 0x00, 0x16, 0x00, 0x32, 0x00, 0x38, 0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91, + 0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x00, 0x05, 0x00, 0x04, 0x00, 0x8c, + 0x00, 0x8d, 0x00, 0x8b, 0x00, 0x8a, 0x01, 0x00, 0x00, 0x1c, 0x00, 0x09, 0x00, 0x03, 0x02, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x11, 0x00, 0x0f, 0x00, 0x00, 0x0c, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, + 0x38, 0x2e, 0x30, 0x2e, 0x31, 0x30, + }, + + []byte{ + 0x16, 0x03, 0x02, 0x00, 0x2a, + 0x02, 0x00, 0x00, 0x26, 0x03, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, + + 0x16, 0x03, 0x02, 0x02, 0xbe, + 0x0b, 0x00, 0x02, 0xba, 0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82, 0x02, 0xb0, 0x30, 0x82, + 0x02, 0x19, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, + 0xb8, 0xca, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, + 0x00, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, + 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, + 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, + 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x30, 0x30, 0x34, + 0x32, 0x34, 0x30, 0x39, 0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x34, 0x32, + 0x34, 0x30, 0x39, 0x30, 0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, + 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, + 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, + 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, + 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, + 0x81, 0x9f, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, + 0x00, 0x03, 0x81, 0x8d, 0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00, 0xbb, 0x79, 0xd6, 0xf5, + 0x17, 0xb5, 0xe5, 0xbf, 0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b, 0x07, 0x43, 0x5a, 0xd0, + 0x03, 0x2d, 0x8a, 0x7a, 0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65, 0x4c, 0x2c, 0x78, 0xb8, + 0x23, 0x8c, 0xb5, 0xb4, 0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62, 0xa5, 0x2c, 0xa5, 0x33, + 0xd6, 0xfe, 0x12, 0x5c, 0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58, 0x7b, 0x26, 0x3f, 0xb5, + 0xcd, 0x04, 0xd3, 0xd0, 0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f, 0x5a, 0xbf, 0xef, 0x42, + 0x71, 0x00, 0xfe, 0x18, 0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1, 0x04, 0x39, 0xc4, 0xa2, + 0x2e, 0xdb, 0x51, 0xc9, 0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01, 0xcf, 0xaf, 0xb1, 0x1d, + 0xb8, 0x71, 0x9a, 0x1d, 0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79, 0x02, 0x03, 0x01, 0x00, + 0x01, 0xa3, 0x81, 0xa7, 0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, + 0x04, 0x14, 0xb1, 0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69, 0xce, 0x23, 0x69, 0xde, + 0xd3, 0x26, 0x8e, 0x18, 0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x6e, 0x30, + 0x6c, 0x80, 0x14, 0xb1, 0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69, 0xce, 0x23, 0x69, + 0xde, 0xd3, 0x26, 0x8e, 0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30, 0x45, 0x31, 0x0b, 0x30, + 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, + 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, + 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, + 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, + 0x74, 0x64, 0x82, 0x09, 0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0c, 0x06, + 0x03, 0x55, 0x1d, 0x13, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81, 0x81, 0x00, 0x08, 0x6c, + 0x45, 0x24, 0xc7, 0x6b, 0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0, 0x14, 0xd7, 0x87, 0x9d, + 0x7a, 0x64, 0x75, 0xb5, 0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae, 0x12, 0x66, 0x1f, 0xeb, + 0x4f, 0x38, 0xb3, 0x6e, 0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5, 0x25, 0x13, 0xb1, 0x18, + 0x7a, 0x24, 0xfb, 0x30, 0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7, 0xd7, 0x31, 0x59, 0xdb, + 0x95, 0xd3, 0x1d, 0x78, 0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d, 0x5a, 0x5f, 0x33, 0xc4, + 0xb6, 0xd8, 0xc9, 0x75, 0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd, 0x98, 0x1f, 0x89, 0x20, + 0x5f, 0xf2, 0xa0, 0x1c, 0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57, 0xe9, 0x70, 0xe8, 0x26, + 0x6d, 0x71, 0x99, 0x9b, 0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7, 0xbd, 0xd9, + 0x16, 0x03, 0x02, 0x00, 0x04, + 0x0e, 0x00, 0x00, 0x00, + }, + + []byte{ + 0x16, 0x03, 0x02, 0x00, 0x86, 0x10, 0x00, 0x00, 0x82, 0x00, 0x80, 0x3b, 0x7a, 0x9b, 0x05, 0xfd, + 0x1b, 0x0d, 0x81, 0xf0, 0xac, 0x59, 0x57, 0x4e, 0xb6, 0xf5, 0x81, 0xed, 0x52, 0x78, 0xc5, 0xff, + 0x36, 0x33, 0x9c, 0x94, 0x31, 0xc3, 0x14, 0x98, 0x5d, 0xa0, 0x49, 0x23, 0x11, 0x67, 0xdf, 0x73, + 0x1b, 0x81, 0x0b, 0xdd, 0x10, 0xda, 0xee, 0xb5, 0x68, 0x61, 0xa9, 0xb6, 0x15, 0xae, 0x1a, 0x11, + 0x31, 0x42, 0x2e, 0xde, 0x01, 0x4b, 0x81, 0x70, 0x03, 0xc8, 0x5b, 0xca, 0x21, 0x88, 0x25, 0xef, + 0x89, 0xf0, 0xb7, 0xff, 0x24, 0x32, 0xd3, 0x14, 0x76, 0xe2, 0x50, 0x5c, 0x2e, 0x75, 0x9d, 0x5c, + 0xa9, 0x80, 0x3d, 0x6f, 0xd5, 0x46, 0xd3, 0xdb, 0x42, 0x6e, 0x55, 0x81, 0x88, 0x42, 0x0e, 0x45, + 0xfe, 0x9e, 0xe4, 0x41, 0x79, 0xcf, 0x71, 0x0e, 0xed, 0x27, 0xa8, 0x20, 0x05, 0xe9, 0x7a, 0x42, + 0x4f, 0x05, 0x10, 0x2e, 0x52, 0x5d, 0x8c, 0x3c, 0x40, 0x49, 0x4c, + + 0x14, 0x03, 0x02, 0x00, 0x01, 0x01, + + 0x16, 0x03, 0x02, 0x00, 0x24, 0x8b, 0x12, 0x24, 0x06, 0xaa, 0x92, 0x74, 0xa1, 0x46, 0x6f, 0xc1, + 0x4e, 0x4a, 0xf7, 0x16, 0xdd, 0xd6, 0xe1, 0x2d, 0x37, 0x0b, 0x44, 0xba, 0xeb, 0xc4, 0x6c, 0xc7, + 0xa0, 0xb7, 0x8c, 0x9d, 0x24, 0xbd, 0x99, 0x33, 0x1e, + }, + + []byte{ + 0x14, 0x03, 0x02, 0x00, 0x01, + 0x01, + + 0x16, 0x03, 0x02, 0x00, 0x24, + 0x6e, 0xd1, 0x3e, 0x49, 0x68, 0xc1, 0xa0, 0xa5, 0xb7, 0xaf, 0xb0, 0x7c, 0x52, 0x1f, 0xf7, 0x2d, + 0x51, 0xf3, 0xa5, 0xb6, 0xf6, 0xd4, 0x18, 0x4b, 0x7a, 0xd5, 0x24, 0x1d, 0x09, 0xb6, 0x41, 0x1c, + 0x1c, 0x98, 0xf6, 0x90, + + 0x17, 0x03, 0x02, 0x00, 0x21, + 0x50, 0xb7, 0x92, 0x4f, 0xd8, 0x78, 0x29, 0xa2, 0xe7, 0xa5, 0xa6, 0xbd, 0x1a, 0x0c, 0xf1, 0x5a, + 0x6e, 0x6c, 0xeb, 0x38, 0x99, 0x9b, 0x3c, 0xfd, 0xee, 0x53, 0xe8, 0x4d, 0x7b, 0xa5, 0x5b, 0x00, + + 0xb9, + + 0x15, 0x03, 0x02, 0x00, 0x16, + 0xc7, 0xc9, 0x5a, 0x72, 0xfb, 0x02, 0xa5, 0x93, 0xdd, 0x69, 0xeb, 0x30, 0x68, 0x5e, 0xbc, 0xe0, + 0x44, 0xb9, 0x59, 0x33, 0x68, 0xa9, + }, } diff --git a/src/pkg/crypto/tls/record_process.go b/src/pkg/crypto/tls/record_process.go deleted file mode 100644 index 77470f04b..000000000 --- a/src/pkg/crypto/tls/record_process.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -// A recordProcessor accepts reassembled records, decrypts and verifies them -// and routes them either to the handshake processor, to up to the application. -// It also accepts requests from the application for the current connection -// state, or for a notification when the state changes. - -import ( - "container/list" - "crypto/subtle" - "hash" -) - -// getConnectionState is a request from the application to get the current -// ConnectionState. -type getConnectionState struct { - reply chan<- ConnectionState -} - -// waitConnectionState is a request from the application to be notified when -// the connection state changes. -type waitConnectionState struct { - reply chan<- ConnectionState -} - -// connectionStateChange is a message from the handshake processor that the -// connection state has changed. -type connectionStateChange struct { - connState ConnectionState -} - -// changeCipherSpec is a message send to the handshake processor to signal that -// the peer is switching ciphers. -type changeCipherSpec struct{} - -// newCipherSpec is a message from the handshake processor that future -// records should be processed with a new cipher and MAC function. -type newCipherSpec struct { - encrypt encryptor - mac hash.Hash -} - -type recordProcessor struct { - decrypt encryptor - mac hash.Hash - seqNum uint64 - handshakeBuf []byte - appDataChan chan<- []byte - requestChan <-chan interface{} - controlChan <-chan interface{} - recordChan <-chan *record - handshakeChan chan<- interface{} - - // recordRead is nil when we don't wish to read any more. - recordRead <-chan *record - // appDataSend is nil when len(appData) == 0. - appDataSend chan<- []byte - // appData contains any application data queued for upstream. - appData []byte - // A list of channels waiting for connState to change. - waitQueue *list.List - connState ConnectionState - shutdown bool - header [13]byte -} - -// drainRequestChannel processes messages from the request channel until it's closed. -func drainRequestChannel(requestChan <-chan interface{}, c ConnectionState) { - for v := range requestChan { - if closed(requestChan) { - return - } - switch r := v.(type) { - case getConnectionState: - r.reply <- c - case waitConnectionState: - r.reply <- c - } - } -} - -func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan interface{}, controlChan <-chan interface{}, recordChan <-chan *record, handshakeChan chan<- interface{}) { - noop := nop{} - p.decrypt = noop - p.mac = noop - p.waitQueue = list.New() - - p.appDataChan = appDataChan - p.requestChan = requestChan - p.controlChan = controlChan - p.recordChan = recordChan - p.handshakeChan = handshakeChan - p.recordRead = recordChan - - for !p.shutdown { - select { - case p.appDataSend <- p.appData: - p.appData = nil - p.appDataSend = nil - p.recordRead = p.recordChan - case c := <-controlChan: - p.processControlMsg(c) - case r := <-requestChan: - p.processRequestMsg(r) - case r := <-p.recordRead: - p.processRecord(r) - } - } - - p.wakeWaiters() - go drainRequestChannel(p.requestChan, p.connState) - go func() { - for _ = range controlChan { - } - }() - - close(handshakeChan) - if len(p.appData) > 0 { - appDataChan <- p.appData - } - close(appDataChan) -} - -func (p *recordProcessor) processRequestMsg(requestMsg interface{}) { - if closed(p.requestChan) { - p.shutdown = true - return - } - - switch r := requestMsg.(type) { - case getConnectionState: - r.reply <- p.connState - case waitConnectionState: - if p.connState.HandshakeComplete { - r.reply <- p.connState - } - p.waitQueue.PushBack(r.reply) - } -} - -func (p *recordProcessor) processControlMsg(msg interface{}) { - connState, ok := msg.(ConnectionState) - if !ok || closed(p.controlChan) { - p.shutdown = true - return - } - - p.connState = connState - p.wakeWaiters() -} - -func (p *recordProcessor) wakeWaiters() { - for i := p.waitQueue.Front(); i != nil; i = i.Next() { - i.Value.(chan<- ConnectionState) <- p.connState - } - p.waitQueue.Init() -} - -func (p *recordProcessor) processRecord(r *record) { - if closed(p.recordChan) { - p.shutdown = true - return - } - - p.decrypt.XORKeyStream(r.payload) - if len(r.payload) < p.mac.Size() { - p.error(alertBadRecordMAC) - return - } - - fillMACHeader(&p.header, p.seqNum, len(r.payload)-p.mac.Size(), r) - p.seqNum++ - - p.mac.Reset() - p.mac.Write(p.header[0:13]) - p.mac.Write(r.payload[0 : len(r.payload)-p.mac.Size()]) - macBytes := p.mac.Sum() - - if subtle.ConstantTimeCompare(macBytes, r.payload[len(r.payload)-p.mac.Size():]) != 1 { - p.error(alertBadRecordMAC) - return - } - - switch r.contentType { - case recordTypeHandshake: - p.processHandshakeRecord(r.payload[0 : len(r.payload)-p.mac.Size()]) - case recordTypeChangeCipherSpec: - if len(r.payload) != 1 || r.payload[0] != 1 { - p.error(alertUnexpectedMessage) - return - } - - p.handshakeChan <- changeCipherSpec{} - newSpec, ok := (<-p.controlChan).(*newCipherSpec) - if !ok { - p.connState.Error = alertUnexpectedMessage - p.shutdown = true - return - } - p.decrypt = newSpec.encrypt - p.mac = newSpec.mac - p.seqNum = 0 - case recordTypeApplicationData: - if p.connState.HandshakeComplete == false { - p.error(alertUnexpectedMessage) - return - } - p.recordRead = nil - p.appData = r.payload[0 : len(r.payload)-p.mac.Size()] - p.appDataSend = p.appDataChan - default: - p.error(alertUnexpectedMessage) - return - } -} - -func (p *recordProcessor) processHandshakeRecord(data []byte) { - if p.handshakeBuf == nil { - p.handshakeBuf = data - } else { - if len(p.handshakeBuf) > maxHandshakeMsg { - p.error(alertInternalError) - return - } - newBuf := make([]byte, len(p.handshakeBuf)+len(data)) - copy(newBuf, p.handshakeBuf) - copy(newBuf[len(p.handshakeBuf):], data) - p.handshakeBuf = newBuf - } - - for len(p.handshakeBuf) >= 4 { - handshakeLen := int(p.handshakeBuf[1])<<16 | - int(p.handshakeBuf[2])<<8 | - int(p.handshakeBuf[3]) - if handshakeLen+4 > len(p.handshakeBuf) { - break - } - - bytes := p.handshakeBuf[0 : handshakeLen+4] - p.handshakeBuf = p.handshakeBuf[handshakeLen+4:] - if bytes[0] == typeFinished { - // Special case because Finished is synchronous: the - // handshake handler has to tell us if it's ok to start - // forwarding application data. - m := new(finishedMsg) - if !m.unmarshal(bytes) { - p.error(alertUnexpectedMessage) - } - p.handshakeChan <- m - var ok bool - p.connState, ok = (<-p.controlChan).(ConnectionState) - if !ok || p.connState.Error != 0 { - p.shutdown = true - return - } - } else { - msg, ok := parseHandshakeMsg(bytes) - if !ok { - p.error(alertUnexpectedMessage) - return - } - p.handshakeChan <- msg - } - } -} - -func (p *recordProcessor) error(err alertType) { - close(p.handshakeChan) - p.connState.Error = err - p.wakeWaiters() - p.shutdown = true -} - -func parseHandshakeMsg(data []byte) (interface{}, bool) { - var m interface { - unmarshal([]byte) bool - } - - switch data[0] { - case typeClientHello: - m = new(clientHelloMsg) - case typeServerHello: - m = new(serverHelloMsg) - case typeCertificate: - m = new(certificateMsg) - case typeServerHelloDone: - m = new(serverHelloDoneMsg) - case typeClientKeyExchange: - m = new(clientKeyExchangeMsg) - case typeNextProtocol: - m = new(nextProtoMsg) - default: - return nil, false - } - - ok := m.unmarshal(data) - return m, ok -} diff --git a/src/pkg/crypto/tls/record_process_test.go b/src/pkg/crypto/tls/record_process_test.go deleted file mode 100644 index fe001a2f9..000000000 --- a/src/pkg/crypto/tls/record_process_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "encoding/hex" - "testing" - "testing/script" -) - -func setup() (appDataChan chan []byte, requestChan chan interface{}, controlChan chan interface{}, recordChan chan *record, handshakeChan chan interface{}) { - rp := new(recordProcessor) - appDataChan = make(chan []byte) - requestChan = make(chan interface{}) - controlChan = make(chan interface{}) - recordChan = make(chan *record) - handshakeChan = make(chan interface{}) - - go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan) - return -} - -func fromHex(s string) []byte { - b, _ := hex.DecodeString(s) - return b -} - -func TestNullConnectionState(t *testing.T) { - _, requestChan, controlChan, recordChan, _ := setup() - defer close(requestChan) - defer close(controlChan) - defer close(recordChan) - - // Test a simple request for the connection state. - replyChan := make(chan ConnectionState) - sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}}) - getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0, ""}}) - - err := script.Perform(0, []*script.Event{sendReq, getReply}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} - -func TestWaitConnectionState(t *testing.T) { - _, requestChan, controlChan, recordChan, _ := setup() - defer close(requestChan) - defer close(controlChan) - defer close(recordChan) - - // Test that waitConnectionState doesn't get a reply until the connection state changes. - replyChan := make(chan ConnectionState) - sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}}) - replyChan2 := make(chan ConnectionState) - sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}}) - getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0, ""}}) - sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1, ""}}) - getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1, ""}}) - - err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} - -func TestHandshakeAssembly(t *testing.T) { - _, requestChan, controlChan, recordChan, handshakeChan := setup() - defer close(requestChan) - defer close(controlChan) - defer close(recordChan) - - // Test the reassembly of a fragmented handshake message. - send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}}) - send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}}) - send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}}) - recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}}) - - err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} - -func TestEarlyApplicationData(t *testing.T) { - _, requestChan, controlChan, recordChan, handshakeChan := setup() - defer close(requestChan) - defer close(controlChan) - defer close(recordChan) - - // Test that applicaton data received before the handshake has completed results in an error. - send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}}) - recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan}) - - err := script.Perform(0, []*script.Event{send, recv}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} - -func TestApplicationData(t *testing.T) { - appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup() - defer close(requestChan) - defer close(controlChan) - defer close(recordChan) - - // Test that the application data is forwarded after a successful Finished message. - send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}}) - recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}}) - send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0, ""}}) - send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}}) - recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}}) - - err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} - -func TestInvalidChangeCipherSpec(t *testing.T) { - appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup() - defer close(requestChan) - defer close(controlChan) - defer close(recordChan) - - send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}}) - recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}}) - send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42, ""}}) - close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan}) - close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan}) - - err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2}) - if err != nil { - t.Errorf("Got error: %s", err) - } -} diff --git a/src/pkg/crypto/tls/record_read.go b/src/pkg/crypto/tls/record_read.go deleted file mode 100644 index 682fde8b6..000000000 --- a/src/pkg/crypto/tls/record_read.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -// The record reader handles reading from the connection and reassembling TLS -// record structures. It loops forever doing this and writes the TLS records to -// it's outbound channel. On error, it closes its outbound channel. - -import ( - "io" - "bufio" -) - -// recordReader loops, reading TLS records from source and writing them to the -// given channel. The channel is closed on EOF or on error. -func recordReader(c chan<- *record, source io.Reader) { - defer close(c) - buf := bufio.NewReader(source) - - for { - var header [5]byte - n, _ := buf.Read(&header) - if n != 5 { - return - } - - recordLength := int(header[3])<<8 | int(header[4]) - if recordLength > maxTLSCiphertext { - return - } - - payload := make([]byte, recordLength) - n, _ = buf.Read(payload) - if n != recordLength { - return - } - - c <- &record{recordType(header[0]), header[1], header[2], payload} - } -} diff --git a/src/pkg/crypto/tls/record_read_test.go b/src/pkg/crypto/tls/record_read_test.go deleted file mode 100644 index f897599ad..000000000 --- a/src/pkg/crypto/tls/record_read_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "bytes" - "testing" - "testing/iotest" -) - -func matchRecord(r1, r2 *record) bool { - if (r1 == nil) != (r2 == nil) { - return false - } - if r1 == nil { - return true - } - return r1.contentType == r2.contentType && - r1.major == r2.major && - r1.minor == r2.minor && - bytes.Compare(r1.payload, r2.payload) == 0 -} - -type recordReaderTest struct { - in []byte - out []*record -} - -var recordReaderTests = []recordReaderTest{ - recordReaderTest{nil, nil}, - recordReaderTest{fromHex("01"), nil}, - recordReaderTest{fromHex("0102"), nil}, - recordReaderTest{fromHex("010203"), nil}, - recordReaderTest{fromHex("01020300"), nil}, - recordReaderTest{fromHex("0102030000"), []*record{&record{1, 2, 3, nil}}}, - recordReaderTest{fromHex("01020300000102030000"), []*record{&record{1, 2, 3, nil}, &record{1, 2, 3, nil}}}, - recordReaderTest{fromHex("0102030001fe0102030002feff"), []*record{&record{1, 2, 3, []byte{0xfe}}, &record{1, 2, 3, []byte{0xfe, 0xff}}}}, - recordReaderTest{fromHex("010203000001020300"), []*record{&record{1, 2, 3, nil}}}, -} - -func TestRecordReader(t *testing.T) { - for i, test := range recordReaderTests { - buf := bytes.NewBuffer(test.in) - c := make(chan *record) - go recordReader(c, buf) - matchRecordReaderOutput(t, i, test, c) - - buf = bytes.NewBuffer(test.in) - buf2 := iotest.OneByteReader(buf) - c = make(chan *record) - go recordReader(c, buf2) - matchRecordReaderOutput(t, i*2, test, c) - } -} - -func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) { - for j, r1 := range test.out { - r2 := <-c - if r2 == nil { - t.Errorf("#%d truncated after %d values", i, j) - break - } - if !matchRecord(r1, r2) { - t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1) - } - } - <-c - if !closed(c) { - t.Errorf("#%d: channel didn't close", i) - } -} diff --git a/src/pkg/crypto/tls/record_write.go b/src/pkg/crypto/tls/record_write.go deleted file mode 100644 index 5f3fb5b16..000000000 --- a/src/pkg/crypto/tls/record_write.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "fmt" - "hash" - "io" -) - -// writerEnableApplicationData is a message which instructs recordWriter to -// start reading and transmitting data from the application data channel. -type writerEnableApplicationData struct{} - -// writerChangeCipherSpec updates the encryption and MAC functions and resets -// the sequence count. -type writerChangeCipherSpec struct { - encryptor encryptor - mac hash.Hash -} - -// writerSetVersion sets the version number bytes that we included in the -// record header for future records. -type writerSetVersion struct { - major, minor uint8 -} - -// A recordWriter accepts messages from the handshake processor and -// application data. It writes them to the outgoing connection and blocks on -// writing. It doesn't read from the application data channel until the -// handshake processor has signaled that the handshake is complete. -type recordWriter struct { - writer io.Writer - encryptor encryptor - mac hash.Hash - seqNum uint64 - major, minor uint8 - shutdown bool - appChan <-chan []byte - controlChan <-chan interface{} - header [13]byte -} - -func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) { - w.writer = writer - w.encryptor = nop{} - w.mac = nop{} - w.appChan = appChan - w.controlChan = controlChan - - for !w.shutdown { - msg := <-controlChan - if _, ok := msg.(writerEnableApplicationData); ok { - break - } - w.processControlMessage(msg) - } - - for !w.shutdown { - // Always process control messages first. - if controlMsg, ok := <-controlChan; ok { - w.processControlMessage(controlMsg) - continue - } - - select { - case controlMsg := <-controlChan: - w.processControlMessage(controlMsg) - case appMsg := <-appChan: - w.processAppMessage(appMsg) - } - } - - if !closed(appChan) { - go func() { - for _ = range appChan { - } - }() - } - if !closed(controlChan) { - go func() { - for _ = range controlChan { - } - }() - } -} - -// fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1. -func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) { - header[0] = uint8(seqNum >> 56) - header[1] = uint8(seqNum >> 48) - header[2] = uint8(seqNum >> 40) - header[3] = uint8(seqNum >> 32) - header[4] = uint8(seqNum >> 24) - header[5] = uint8(seqNum >> 16) - header[6] = uint8(seqNum >> 8) - header[7] = uint8(seqNum) - header[8] = uint8(r.contentType) - header[9] = r.major - header[10] = r.minor - header[11] = uint8(length >> 8) - header[12] = uint8(length) -} - -func (w *recordWriter) writeRecord(r *record) { - w.mac.Reset() - - fillMACHeader(&w.header, w.seqNum, len(r.payload), r) - - w.mac.Write(w.header[0:13]) - w.mac.Write(r.payload) - macBytes := w.mac.Sum() - - w.encryptor.XORKeyStream(r.payload) - w.encryptor.XORKeyStream(macBytes) - - length := len(r.payload) + len(macBytes) - w.header[11] = uint8(length >> 8) - w.header[12] = uint8(length) - w.writer.Write(w.header[8:13]) - w.writer.Write(r.payload) - w.writer.Write(macBytes) - - w.seqNum++ -} - -func (w *recordWriter) processControlMessage(controlMsg interface{}) { - if controlMsg == nil { - w.shutdown = true - return - } - - switch msg := controlMsg.(type) { - case writerChangeCipherSpec: - w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}}) - w.encryptor = msg.encryptor - w.mac = msg.mac - w.seqNum = 0 - case writerSetVersion: - w.major = msg.major - w.minor = msg.minor - case alert: - w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}}) - case handshakeMessage: - // TODO(agl): marshal may return a slice too large for a single record. - w.writeRecord(&record{recordTypeHandshake, w.major, w.minor, msg.marshal()}) - default: - fmt.Printf("processControlMessage: unknown %#v\n", msg) - } -} - -func (w *recordWriter) processAppMessage(appMsg []byte) { - if closed(w.appChan) { - w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}}) - w.shutdown = true - return - } - - var done int - for done < len(appMsg) { - todo := len(appMsg) - if todo > maxTLSPlaintext { - todo = maxTLSPlaintext - } - w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]}) - done += todo - } -} diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index 5fbf850da..1a5da3ac4 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -6,158 +6,16 @@ package tls import ( - "io" "os" "net" - "time" ) -// A Conn represents a secure connection. -type Conn struct { - net.Conn - writeChan chan<- []byte - readChan <-chan []byte - requestChan chan<- interface{} - readBuf []byte - eof bool - readTimeout, writeTimeout int64 -} - -func timeout(c chan<- bool, nsecs int64) { - time.Sleep(nsecs) - c <- true -} - -func (tls *Conn) Read(p []byte) (int, os.Error) { - if len(tls.readBuf) == 0 { - if tls.eof { - return 0, os.EOF - } - - var timeoutChan chan bool - if tls.readTimeout > 0 { - timeoutChan = make(chan bool) - go timeout(timeoutChan, tls.readTimeout) - } - - select { - case b := <-tls.readChan: - tls.readBuf = b - case <-timeoutChan: - return 0, os.EAGAIN - } - - // TLS distinguishes between orderly closes and truncations. An - // orderly close is represented by a zero length slice. - if closed(tls.readChan) { - return 0, io.ErrUnexpectedEOF - } - if len(tls.readBuf) == 0 { - tls.eof = true - return 0, os.EOF - } - } - - n := copy(p, tls.readBuf) - tls.readBuf = tls.readBuf[n:] - return n, nil -} - -func (tls *Conn) Write(p []byte) (int, os.Error) { - if tls.eof || closed(tls.readChan) { - return 0, os.EOF - } - - var timeoutChan chan bool - if tls.writeTimeout > 0 { - timeoutChan = make(chan bool) - go timeout(timeoutChan, tls.writeTimeout) - } - - select { - case tls.writeChan <- p: - case <-timeoutChan: - return 0, os.EAGAIN - } - - return len(p), nil -} - -func (tls *Conn) Close() os.Error { - close(tls.writeChan) - close(tls.requestChan) - tls.eof = true - return nil -} - -func (tls *Conn) SetTimeout(nsec int64) os.Error { - tls.readTimeout = nsec - tls.writeTimeout = nsec - return nil -} - -func (tls *Conn) SetReadTimeout(nsec int64) os.Error { - tls.readTimeout = nsec - return nil -} - -func (tls *Conn) SetWriteTimeout(nsec int64) os.Error { - tls.writeTimeout = nsec - return nil -} - -func (tls *Conn) GetConnectionState() ConnectionState { - replyChan := make(chan ConnectionState) - tls.requestChan <- getConnectionState{replyChan} - return <-replyChan -} - -func (tls *Conn) WaitConnectionState() ConnectionState { - replyChan := make(chan ConnectionState) - tls.requestChan <- waitConnectionState{replyChan} - return <-replyChan -} - -type handshaker interface { - loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) -} - -// Server establishes a secure connection over the given connection and acts -// as a TLS server. -func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn { - if config == nil { - config = defaultConfig() - } - tls := new(Conn) - tls.Conn = conn - - writeChan := make(chan []byte) - readChan := make(chan []byte) - requestChan := make(chan interface{}) - - tls.writeChan = writeChan - tls.readChan = readChan - tls.requestChan = requestChan - - handshakeWriterChan := make(chan interface{}) - processorHandshakeChan := make(chan interface{}) - handshakeProcessorChan := make(chan interface{}) - readerProcessorChan := make(chan *record) - - go new(recordWriter).loop(conn, writeChan, handshakeWriterChan) - go recordReader(readerProcessorChan, conn) - go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan) - go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config) - - return tls -} - func Server(conn net.Conn, config *Config) *Conn { - return startTLSGoroutines(conn, new(serverHandshake), config) + return &Conn{conn: conn, config: config} } func Client(conn net.Conn, config *Config) *Conn { - return startTLSGoroutines(conn, new(clientHandshake), config) + return &Conn{conn: conn, config: config, isClient: true} } type Listener struct { @@ -180,22 +38,24 @@ func (l *Listener) Addr() net.Addr { return l.listener.Addr() } // NewListener creates a Listener which accepts connections from an inner // Listener and wraps each connection with Server. +// The configuration config must be non-nil and must have +// at least one certificate. func NewListener(listener net.Listener, config *Config) (l *Listener) { - if config == nil { - config = defaultConfig() - } l = new(Listener) l.listener = listener l.config = config return } -func Listen(network, laddr string) (net.Listener, os.Error) { +func Listen(network, laddr string, config *Config) (net.Listener, os.Error) { + if config == nil || len(config.Certificates) == 0 { + return nil, os.NewError("tls.Listen: no certificates in configuration") + } l, err := net.Listen(network, laddr) if err != nil { return nil, err } - return NewListener(l, nil), nil + return NewListener(l, config), nil } func Dial(network, laddr, raddr string) (net.Conn, os.Error) { |