diff options
Diffstat (limited to 'src/pkg/crypto/tls')
-rw-r--r-- | src/pkg/crypto/tls/common.go | 14 | ||||
-rw-r--r-- | src/pkg/crypto/tls/conn.go | 14 | ||||
-rw-r--r-- | src/pkg/crypto/tls/generate_cert.go | 10 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_client.go | 33 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_client_test.go | 2 | ||||
-rw-r--r-- | src/pkg/crypto/tls/handshake_messages_test.go | 8 | ||||
-rw-r--r-- | src/pkg/crypto/tls/tls.go | 19 |
7 files changed, 70 insertions, 30 deletions
diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index 7135f3d0f..fb2916ae0 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -7,6 +7,7 @@ package tls import ( "crypto/rand" "crypto/rsa" + "crypto/x509" "io" "io/ioutil" "sync" @@ -92,9 +93,13 @@ const ( // ConnectionState records basic TLS details about the connection. type ConnectionState struct { - HandshakeComplete bool - CipherSuite uint16 - NegotiatedProtocol string + HandshakeComplete bool + CipherSuite uint16 + NegotiatedProtocol string + NegotiatedProtocolIsMutual bool + + // the certificate chain that was presented by the other side + PeerCertificates []*x509.Certificate } // A Config structure is used to configure a TLS client or server. After one @@ -120,7 +125,6 @@ type Config struct { RootCAs *CASet // NextProtos is a list of supported, application level protocols. - // Currently only server-side handling is supported. NextProtos []string // ServerName is included in the client's handshake to support virtual @@ -251,7 +255,7 @@ var varDefaultCipherSuites []uint16 func initDefaultCipherSuites() { varDefaultCipherSuites = make([]uint16, len(cipherSuites)) i := 0 - for id, _ := range cipherSuites { + for id := range cipherSuites { varDefaultCipherSuites[i] = id i++ } diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go index d203e8d51..b94e235c8 100644 --- a/src/pkg/crypto/tls/conn.go +++ b/src/pkg/crypto/tls/conn.go @@ -35,7 +35,8 @@ type Conn struct { ocspResponse []byte // stapled OCSP response peerCertificates []*x509.Certificate - clientProtocol string + clientProtocol string + clientProtocolFallback bool // first permanent error errMutex sync.Mutex @@ -761,7 +762,9 @@ func (c *Conn) ConnectionState() ConnectionState { state.HandshakeComplete = c.handshakeComplete if c.handshakeComplete { state.NegotiatedProtocol = c.clientProtocol + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates } return state @@ -776,15 +779,6 @@ func (c *Conn) OCSPResponse() []byte { return c.ocspResponse } -// PeerCertificates returns the certificate chain that was presented by the -// other side. -func (c *Conn) PeerCertificates() []*x509.Certificate { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - return c.peerCertificates -} - // VerifyHostname checks that the peer certificate chain is valid for // connecting to host. If so, it returns nil; if not, it returns an os.Error // describing the problem. diff --git a/src/pkg/crypto/tls/generate_cert.go b/src/pkg/crypto/tls/generate_cert.go index 3e0c63938..5b8c700e5 100644 --- a/src/pkg/crypto/tls/generate_cert.go +++ b/src/pkg/crypto/tls/generate_cert.go @@ -25,7 +25,7 @@ func main() { priv, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { - log.Exitf("failed to generate private key: %s", err) + log.Fatalf("failed to generate private key: %s", err) return } @@ -46,20 +46,20 @@ func main() { derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - log.Exitf("Failed to create certificate: %s", err) + log.Fatalf("Failed to create certificate: %s", err) return } - certOut, err := os.Open("cert.pem", os.O_WRONLY|os.O_CREAT, 0644) + certOut, err := os.Create("cert.pem") if err != nil { - log.Exitf("failed to open cert.pem for writing: %s", err) + log.Fatalf("failed to open cert.pem for writing: %s", err) return } pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) certOut.Close() log.Print("written cert.pem\n") - keyOut, err := os.Open("key.pem", os.O_WRONLY|os.O_CREAT, 0600) + keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREAT|os.O_TRUNC, 0600) if err != nil { log.Print("failed to open key.pem for writing:", err) return diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index a325a9b95..540b25c87 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -29,6 +29,7 @@ func (c *Conn) clientHandshake() os.Error { serverName: c.config.ServerName, supportedCurves: []uint16{curveP256, curveP384, curveP521}, supportedPoints: []uint8{pointFormatUncompressed}, + nextProtoNeg: len(c.config.NextProtos) > 0, } t := uint32(c.config.time()) @@ -66,6 +67,11 @@ func (c *Conn) clientHandshake() os.Error { return c.sendAlert(alertUnexpectedMessage) } + if !hello.nextProtoNeg && serverHello.nextProtoNeg { + c.sendAlert(alertHandshakeFailure) + return os.ErrorString("server advertised unrequested NPN") + } + suite, suiteId := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) if suite == nil { return c.sendAlert(alertHandshakeFailure) @@ -267,6 +273,17 @@ func (c *Conn) clientHandshake() os.Error { c.out.prepareCipherSpec(clientCipher, clientHash) c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + if serverHello.nextProtoNeg { + nextProto := new(nextProtoMsg) + proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos) + nextProto.proto = proto + c.clientProtocol = proto + c.clientProtocolFallback = fallback + + finishedHash.Write(nextProto.marshal()) + c.writeRecord(recordTypeHandshake, nextProto.marshal()) + } + finished := new(finishedMsg) finished.verifyData = finishedHash.clientSum(masterSecret) finishedHash.Write(finished.marshal()) @@ -299,3 +316,19 @@ func (c *Conn) clientHandshake() os.Error { c.cipherSuite = suiteId return nil } + +// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the +// set of client and server supported protocols. The set of client supported +// protocols must not be empty. It returns the resulting protocol and flag +// indicating if the fallback case was reached. +func mutualProtocol(clientProtos, serverProtos []string) (string, bool) { + for _, s := range serverProtos { + for _, c := range clientProtos { + if s == c { + return s, false + } + } + } + + return clientProtos[0], true +} diff --git a/src/pkg/crypto/tls/handshake_client_test.go b/src/pkg/crypto/tls/handshake_client_test.go index fd1f145cf..3f91c7acf 100644 --- a/src/pkg/crypto/tls/handshake_client_test.go +++ b/src/pkg/crypto/tls/handshake_client_test.go @@ -50,7 +50,7 @@ func TestRunClient(t *testing.T) { testConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} - conn, err := Dial("tcp", "", "127.0.0.1:10443", testConfig) + conn, err := Dial("tcp", "127.0.0.1:10443", testConfig) if err != nil { t.Fatal(err) } diff --git a/src/pkg/crypto/tls/handshake_messages_test.go b/src/pkg/crypto/tls/handshake_messages_test.go index 21577dd0b..f5e94e269 100644 --- a/src/pkg/crypto/tls/handshake_messages_test.go +++ b/src/pkg/crypto/tls/handshake_messages_test.go @@ -34,7 +34,11 @@ func TestMarshalUnmarshal(t *testing.T) { for i, iface := range tests { ty := reflect.NewValue(iface).Type() - for j := 0; j < 100; j++ { + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { v, ok := quick.Value(ty, rand) if !ok { t.Errorf("#%d: failed to create value", i) @@ -117,7 +121,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.ocspStapling = rand.Intn(10) > 5 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) m.supportedCurves = make([]uint16, rand.Intn(5)+1) - for i, _ := range m.supportedCurves { + for i := range m.supportedCurves { m.supportedCurves[i] = uint16(rand.Intn(30000)) } diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index e8290d728..7de44bbd2 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -87,8 +87,9 @@ func Listen(network, laddr string, config *Config) (*Listener, os.Error) { // Dial interprets a nil configuration as equivalent to // the zero configuration; see the documentation of Config // for the defaults. -func Dial(network, laddr, raddr string, config *Config) (*Conn, os.Error) { - c, err := net.Dial(network, laddr, raddr) +func Dial(network, addr string, config *Config) (*Conn, os.Error) { + raddr := addr + c, err := net.Dial(network, raddr) if err != nil { return nil, err } @@ -123,7 +124,16 @@ func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err os. if err != nil { return } + keyPEMBlock, err := ioutil.ReadFile(keyFile) + if err != nil { + return + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err os.Error) { var certDERBlock *pem.Block for { certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) @@ -140,11 +150,6 @@ func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err os. return } - keyPEMBlock, err := ioutil.ReadFile(keyFile) - if err != nil { - return - } - keyDERBlock, _ := pem.Decode(keyPEMBlock) if keyDERBlock == nil { err = os.ErrorString("crypto/tls: failed to parse key PEM data") |