summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Langley <agl@golang.org>2009-11-21 15:53:03 -0800
committerAdam Langley <agl@golang.org>2009-11-21 15:53:03 -0800
commitde7d1c897894d848ca5c6db059cd5810484da875 (patch)
tree0234997a84bac1860ad03b82785c5c3cf5857e43
parentda3ae8661a33a1b8367b4d954aa04dfa8aa220ac (diff)
downloadgolang-de7d1c897894d848ca5c6db059cd5810484da875.tar.gz
crypto/tls: add initial client implementation.
R=rsc, agl CC=golang-dev http://codereview.appspot.com/157076
-rw-r--r--src/pkg/crypto/tls/Makefile2
-rw-r--r--src/pkg/crypto/tls/ca_set.go75
-rw-r--r--src/pkg/crypto/tls/common.go4
-rw-r--r--src/pkg/crypto/tls/handshake_client.go225
-rw-r--r--src/pkg/crypto/tls/handshake_messages.go67
-rw-r--r--src/pkg/crypto/tls/handshake_messages_test.go39
-rw-r--r--src/pkg/crypto/tls/handshake_server.go4
-rw-r--r--src/pkg/crypto/tls/record_process.go8
-rw-r--r--src/pkg/crypto/tls/tls.go22
9 files changed, 439 insertions, 7 deletions
diff --git a/src/pkg/crypto/tls/Makefile b/src/pkg/crypto/tls/Makefile
index dd3df2957..0c2d33998 100644
--- a/src/pkg/crypto/tls/Makefile
+++ b/src/pkg/crypto/tls/Makefile
@@ -8,12 +8,14 @@ TARG=crypto/tls
GOFILES=\
alert.go\
common.go\
+ handshake_client.go\
handshake_messages.go\
handshake_server.go\
prf.go\
record_process.go\
record_read.go\
record_write.go\
+ ca_set.go\
tls.go\
include $(GOROOT)/src/Make.pkg
diff --git a/src/pkg/crypto/tls/ca_set.go b/src/pkg/crypto/tls/ca_set.go
new file mode 100644
index 000000000..e8cddd6f4
--- /dev/null
+++ b/src/pkg/crypto/tls/ca_set.go
@@ -0,0 +1,75 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+ "crypto/x509";
+ "encoding/pem";
+)
+
+// A CASet is a set of certificates.
+type CASet struct {
+ bySubjectKeyId map[string]*x509.Certificate;
+ byName map[string]*x509.Certificate;
+}
+
+func NewCASet() *CASet {
+ return &CASet{
+ make(map[string]*x509.Certificate),
+ make(map[string]*x509.Certificate),
+ }
+}
+
+func nameToKey(name *x509.Name) string {
+ return name.Country + "/" + name.OrganizationalUnit + "/" + name.OrganizationalUnit + "/" + name.CommonName
+}
+
+// FindParent attempts to find the certificate in s which signs the given
+// certificate. If no such certificate can be found, it returns nil.
+func (s *CASet) FindParent(cert *x509.Certificate) (parent *x509.Certificate) {
+ var ok bool;
+
+ if len(cert.AuthorityKeyId) > 0 {
+ parent, ok = s.bySubjectKeyId[string(cert.AuthorityKeyId)]
+ } else {
+ parent, ok = s.byName[nameToKey(&cert.Issuer)]
+ }
+
+ if !ok {
+ return nil
+ }
+ return parent;
+}
+
+// SetFromPEM attempts to parse a series of PEM encoded root certificates. It
+// appends any certificates found to s and returns true if any certificates
+// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will
+// contains the system wide set of root CAs in a format suitable for this
+// function.
+func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) {
+ for len(pemCerts) > 0 {
+ var block *pem.Block;
+ block, pemCerts = pem.Decode(pemCerts);
+ if block == nil {
+ break
+ }
+ if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
+ continue
+ }
+
+ cert, err := x509.ParseCertificate(block.Bytes);
+ if err != nil {
+ continue
+ }
+
+ if len(cert.SubjectKeyId) > 0 {
+ s.bySubjectKeyId[string(cert.SubjectKeyId)] = cert
+ }
+ s.byName[nameToKey(&cert.Subject)] = cert;
+ ok = true;
+ }
+
+ return;
+}
diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go
index e295cd472..e1318a893 100644
--- a/src/pkg/crypto/tls/common.go
+++ b/src/pkg/crypto/tls/common.go
@@ -17,6 +17,9 @@ const (
maxTLSCiphertext = 16384 + 2048;
// maxHandshakeMsg is the largest single handshake message that we'll buffer.
maxHandshakeMsg = 65536;
+ // defaultMajor and defaultMinor are the maximum TLS version that we support.
+ defaultMajor = 3;
+ defaultMinor = 2;
)
@@ -64,6 +67,7 @@ type Config struct {
// Time returns the current time as the number of seconds since the epoch.
Time func() int64;
Certificates []Certificate;
+ RootCAs *CASet;
}
type Certificate struct {
diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go
new file mode 100644
index 000000000..db9bb8cb3
--- /dev/null
+++ b/src/pkg/crypto/tls/handshake_client.go
@@ -0,0 +1,225 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+ "crypto/hmac";
+ "crypto/rc4";
+ "crypto/rsa";
+ "crypto/sha1";
+ "crypto/subtle";
+ "crypto/x509";
+ "io";
+)
+
+// A serverHandshake performs the server side of the TLS 1.1 handshake protocol.
+type clientHandshake struct {
+ writeChan chan<- interface{};
+ controlChan chan<- interface{};
+ msgChan <-chan interface{};
+ config *Config;
+}
+
+func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
+ h.writeChan = writeChan;
+ h.controlChan = controlChan;
+ h.msgChan = msgChan;
+ h.config = config;
+
+ defer close(writeChan);
+ defer close(controlChan);
+
+ finishedHash := newFinishedHash();
+
+ hello := &clientHelloMsg{
+ major: defaultMajor,
+ minor: defaultMinor,
+ cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
+ compressionMethods: []uint8{compressionNone},
+ random: make([]byte, 32),
+ };
+
+ currentTime := uint32(config.Time());
+ hello.random[0] = byte(currentTime >> 24);
+ hello.random[1] = byte(currentTime >> 16);
+ hello.random[2] = byte(currentTime >> 8);
+ hello.random[3] = byte(currentTime);
+ _, err := io.ReadFull(config.Rand, hello.random[4:len(hello.random)]);
+ if err != nil {
+ h.error(alertInternalError);
+ return;
+ }
+
+ finishedHash.Write(hello.marshal());
+ writeChan <- writerSetVersion{defaultMajor, defaultMinor};
+ writeChan <- hello;
+
+ serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg);
+ if !ok {
+ h.error(alertUnexpectedMessage);
+ return;
+ }
+ finishedHash.Write(serverHello.marshal());
+ major, minor, ok := mutualVersion(serverHello.major, serverHello.minor);
+ if !ok {
+ h.error(alertProtocolVersion);
+ return;
+ }
+
+ writeChan <- writerSetVersion{major, minor};
+
+ if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA ||
+ serverHello.compressionMethod != compressionNone {
+ h.error(alertUnexpectedMessage);
+ return;
+ }
+
+ certMsg, ok := h.readHandshakeMsg().(*certificateMsg);
+ if !ok || len(certMsg.certificates) == 0 {
+ h.error(alertUnexpectedMessage);
+ return;
+ }
+ finishedHash.Write(certMsg.marshal());
+
+ certs := make([]*x509.Certificate, len(certMsg.certificates));
+ for i, asn1Data := range certMsg.certificates {
+ cert, err := x509.ParseCertificate(asn1Data);
+ if err != nil {
+ h.error(alertBadCertificate);
+ return;
+ }
+ certs[i] = cert;
+ }
+
+ // TODO(agl): do better validation of certs: max path length, name restrictions etc.
+ for i := 1; i < len(certs); i++ {
+ if certs[i-1].CheckSignatureFrom(certs[i]) != nil {
+ h.error(alertBadCertificate);
+ return;
+ }
+ }
+
+ if config.RootCAs != nil {
+ root := config.RootCAs.FindParent(certs[len(certs)-1]);
+ if root == nil {
+ h.error(alertBadCertificate);
+ return;
+ }
+ if certs[len(certs)-1].CheckSignatureFrom(root) != nil {
+ h.error(alertBadCertificate);
+ return;
+ }
+ }
+
+ pub, ok := certs[0].PublicKey.(*rsa.PublicKey);
+ if !ok {
+ h.error(alertUnsupportedCertificate);
+ return;
+ }
+
+ shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg);
+ if !ok {
+ h.error(alertUnexpectedMessage);
+ return;
+ }
+ finishedHash.Write(shd.marshal());
+
+ ckx := new(clientKeyExchangeMsg);
+ preMasterSecret := make([]byte, 48);
+ // Note that the version number in the preMasterSecret must be the
+ // version offered in the ClientHello.
+ preMasterSecret[0] = defaultMajor;
+ preMasterSecret[1] = defaultMinor;
+ _, err = io.ReadFull(config.Rand, preMasterSecret[2:len(preMasterSecret)]);
+ if err != nil {
+ h.error(alertInternalError);
+ return;
+ }
+
+ ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret);
+ if err != nil {
+ h.error(alertInternalError);
+ return;
+ }
+
+ finishedHash.Write(ckx.marshal());
+ writeChan <- ckx;
+
+ suite := cipherSuites[0];
+ masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
+ keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength);
+
+ cipher, _ := rc4.NewCipher(clientKey);
+ writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)};
+
+ finished := new(finishedMsg);
+ finished.verifyData = finishedHash.clientSum(masterSecret);
+ finishedHash.Write(finished.marshal());
+ writeChan <- finished;
+
+ // TODO(agl): this is cut-through mode which should probably be an option.
+ writeChan <- writerEnableApplicationData{};
+
+ _, ok = h.readHandshakeMsg().(changeCipherSpec);
+ if !ok {
+ h.error(alertUnexpectedMessage);
+ return;
+ }
+
+ cipher2, _ := rc4.NewCipher(serverKey);
+ controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)};
+
+ serverFinished, ok := h.readHandshakeMsg().(*finishedMsg);
+ if !ok {
+ h.error(alertUnexpectedMessage);
+ return;
+ }
+
+ verify := finishedHash.serverSum(masterSecret);
+ if len(verify) != len(serverFinished.verifyData) ||
+ subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
+ h.error(alertHandshakeFailure);
+ return;
+ }
+
+ controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0};
+
+ // This should just block forever.
+ _ = h.readHandshakeMsg();
+ h.error(alertUnexpectedMessage);
+ return;
+}
+
+func (h *clientHandshake) readHandshakeMsg() interface{} {
+ v := <-h.msgChan;
+ if closed(h.msgChan) {
+ // If the channel closed then the processor received an error
+ // from the peer and we don't want to echo it back to them.
+ h.msgChan = nil;
+ return 0;
+ }
+ if _, ok := v.(alert); ok {
+ // We got an alert from the processor. We forward to the writer
+ // and shutdown.
+ h.writeChan <- v;
+ h.msgChan = nil;
+ return 0;
+ }
+ return v;
+}
+
+func (h *clientHandshake) error(e alertType) {
+ if h.msgChan != nil {
+ // If we didn't get an error from the processor, then we need
+ // to tell it about the error.
+ go func() {
+ for _ = range h.msgChan {
+ }
+ }();
+ h.controlChan <- ConnectionState{false, "", e};
+ close(h.controlChan);
+ h.writeChan <- alert{alertLevelError, e};
+ }
+}
diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go
index 87e2e779e..65dae8762 100644
--- a/src/pkg/crypto/tls/handshake_messages.go
+++ b/src/pkg/crypto/tls/handshake_messages.go
@@ -45,7 +45,7 @@ func (m *clientHelloMsg) marshal() []byte {
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
- if len(data) < 39 {
+ if len(data) < 43 {
return false
}
m.raw = data;
@@ -120,6 +120,30 @@ func (m *serverHelloMsg) marshal() []byte {
return x;
}
+func (m *serverHelloMsg) unmarshal(data []byte) bool {
+ if len(data) < 42 {
+ return false
+ }
+ m.raw = data;
+ m.major = data[4];
+ m.minor = data[5];
+ m.random = data[6:38];
+ sessionIdLen := int(data[38]);
+ if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
+ return false
+ }
+ m.sessionId = data[39 : 39+sessionIdLen];
+ data = data[39+sessionIdLen : len(data)];
+ if len(data) < 3 {
+ return false
+ }
+ m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]);
+ m.compressionMethod = data[2];
+
+ // Trailing data is allowed because extensions may be present.
+ return true;
+}
+
type certificateMsg struct {
raw []byte;
certificates [][]byte;
@@ -160,6 +184,43 @@ func (m *certificateMsg) marshal() (x []byte) {
return;
}
+func (m *certificateMsg) unmarshal(data []byte) bool {
+ if len(data) < 7 {
+ return false
+ }
+
+ m.raw = data;
+ certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]);
+ if uint32(len(data)) != certsLen+7 {
+ return false
+ }
+
+ numCerts := 0;
+ d := data[7:len(data)];
+ for certsLen > 0 {
+ if len(d) < 4 {
+ return false
+ }
+ certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]);
+ if uint32(len(d)) < 3+certLen {
+ return false
+ }
+ d = d[3+certLen : len(d)];
+ certsLen -= 3 + certLen;
+ numCerts++;
+ }
+
+ m.certificates = make([][]byte, numCerts);
+ d = data[7:len(data)];
+ for i := 0; i < numCerts; i++ {
+ certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]);
+ m.certificates[i] = d[3 : 3+certLen];
+ d = d[3+certLen : len(d)];
+ }
+
+ return true;
+}
+
type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) marshal() []byte {
@@ -168,6 +229,10 @@ func (m *serverHelloDoneMsg) marshal() []byte {
return x;
}
+func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
+ return len(data) == 4
+}
+
type clientKeyExchangeMsg struct {
raw []byte;
ciphertext []byte;
diff --git a/src/pkg/crypto/tls/handshake_messages_test.go b/src/pkg/crypto/tls/handshake_messages_test.go
index 5dafc388b..c580f65c6 100644
--- a/src/pkg/crypto/tls/handshake_messages_test.go
+++ b/src/pkg/crypto/tls/handshake_messages_test.go
@@ -13,6 +13,8 @@ import (
var tests = []interface{}{
&clientHelloMsg{},
+ &serverHelloMsg{},
+ &certificateMsg{},
&clientKeyExchangeMsg{},
&finishedMsg{},
}
@@ -59,6 +61,20 @@ func TestMarshalUnmarshal(t *testing.T) {
}
}
+func TestFuzz(t *testing.T) {
+ rand := rand.New(rand.NewSource(0));
+ for _, iface := range tests {
+ m := iface.(testMessage);
+
+ for j := 0; j < 1000; j++ {
+ len := rand.Intn(100);
+ bytes := randomBytes(len, rand);
+ // This just looks for crashes due to bounds errors etc.
+ m.unmarshal(bytes);
+ }
+ }
+}
+
func randomBytes(n int, rand *rand.Rand) []byte {
r := make([]byte, n);
for i := 0; i < n; i++ {
@@ -82,9 +98,30 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.NewValue(m);
}
+func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+ m := &serverHelloMsg{};
+ m.major = uint8(rand.Intn(256));
+ m.minor = uint8(rand.Intn(256));
+ m.random = randomBytes(32, rand);
+ m.sessionId = randomBytes(rand.Intn(32), rand);
+ m.cipherSuite = uint16(rand.Int31());
+ m.compressionMethod = uint8(rand.Intn(256));
+ return reflect.NewValue(m);
+}
+
+func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+ m := &certificateMsg{};
+ numCerts := rand.Intn(20);
+ m.certificates = make([][]byte, numCerts);
+ for i := 0; i < numCerts; i++ {
+ m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
+ }
+ return reflect.NewValue(m);
+}
+
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientKeyExchangeMsg{};
- m.ciphertext = randomBytes(rand.Intn(1000), rand);
+ m.ciphertext = randomBytes(rand.Intn(1000)+1, rand);
return reflect.NewValue(m);
}
diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go
index 0e04c42af..2e7760365 100644
--- a/src/pkg/crypto/tls/handshake_server.go
+++ b/src/pkg/crypto/tls/handshake_server.go
@@ -224,12 +224,12 @@ func (h *serverHandshake) error(e alertType) {
if h.msgChan != nil {
// If we didn't get an error from the processor, then we need
// to tell it about the error.
- h.controlChan <- ConnectionState{false, "", e};
- close(h.controlChan);
go func() {
for _ = range h.msgChan {
}
}();
+ h.controlChan <- ConnectionState{false, "", e};
+ close(h.controlChan);
h.writeChan <- alert{alertLevelError, e};
}
}
diff --git a/src/pkg/crypto/tls/record_process.go b/src/pkg/crypto/tls/record_process.go
index b7edd9fd1..e356d67ca 100644
--- a/src/pkg/crypto/tls/record_process.go
+++ b/src/pkg/crypto/tls/record_process.go
@@ -210,7 +210,7 @@ func (p *recordProcessor) processRecord(r *record) {
return;
}
p.recordRead = nil;
- p.appData = r.payload;
+ p.appData = r.payload[0 : len(r.payload)-p.mac.Size()];
p.appDataSend = p.appDataChan;
default:
p.error(alertUnexpectedMessage);
@@ -283,6 +283,12 @@ func parseHandshakeMsg(data []byte) (interface{}, bool) {
switch data[0] {
case typeClientHello:
m = new(clientHelloMsg)
+ case typeServerHello:
+ m = new(serverHelloMsg)
+ case typeCertificate:
+ m = new(certificateMsg)
+ case typeServerHelloDone:
+ m = new(serverHelloDoneMsg)
case typeClientKeyExchange:
m = new(clientKeyExchangeMsg)
default:
diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go
index a162487de..c5a0f69d5 100644
--- a/src/pkg/crypto/tls/tls.go
+++ b/src/pkg/crypto/tls/tls.go
@@ -112,9 +112,19 @@ func (tls *Conn) GetConnectionState() ConnectionState {
return <-replyChan;
}
+func (tls *Conn) WaitConnectionState() ConnectionState {
+ replyChan := make(chan ConnectionState);
+ tls.requestChan <- waitConnectionState{replyChan};
+ return <-replyChan;
+}
+
+type handshaker interface {
+ loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config);
+}
+
// Server establishes a secure connection over the given connection and acts
// as a TLS server.
-func Server(conn net.Conn, config *Config) *Conn {
+func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn {
tls := new(Conn);
tls.Conn = conn;
@@ -134,11 +144,19 @@ func Server(conn net.Conn, config *Config) *Conn {
go new(recordWriter).loop(conn, writeChan, handshakeWriterChan);
go recordReader(readerProcessorChan, conn);
go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan);
- go new(serverHandshake).loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config);
+ go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config);
return tls;
}
+func Server(conn net.Conn, config *Config) *Conn {
+ return startTLSGoroutines(conn, new(serverHandshake), config)
+}
+
+func Client(conn net.Conn, config *Config) *Conn {
+ return startTLSGoroutines(conn, new(clientHandshake), config)
+}
+
type Listener struct {
listener net.Listener;
config *Config;