summaryrefslogtreecommitdiff
path: root/src/pkg/crypto/tls/handshake_messages.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/crypto/tls/handshake_messages.go')
-rw-r--r--src/pkg/crypto/tls/handshake_messages.go115
1 files changed, 83 insertions, 32 deletions
diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go
index 83952000f..7bcaa5eb9 100644
--- a/src/pkg/crypto/tls/handshake_messages.go
+++ b/src/pkg/crypto/tls/handshake_messages.go
@@ -7,20 +7,21 @@ package tls
import "bytes"
type clientHelloMsg struct {
- raw []byte
- vers uint16
- random []byte
- sessionId []byte
- cipherSuites []uint16
- compressionMethods []uint8
- nextProtoNeg bool
- serverName string
- ocspStapling bool
- supportedCurves []uint16
- supportedPoints []uint8
- ticketSupported bool
- sessionTicket []uint8
- signatureAndHashes []signatureAndHash
+ raw []byte
+ vers uint16
+ random []byte
+ sessionId []byte
+ cipherSuites []uint16
+ compressionMethods []uint8
+ nextProtoNeg bool
+ serverName string
+ ocspStapling bool
+ supportedCurves []CurveID
+ supportedPoints []uint8
+ ticketSupported bool
+ sessionTicket []uint8
+ signatureAndHashes []signatureAndHash
+ secureRenegotiation bool
}
func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -38,11 +39,12 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
- eqUint16s(m.supportedCurves, m1.supportedCurves) &&
+ eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
- eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
+ eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
+ m.secureRenegotiation == m1.secureRenegotiation
}
func (m *clientHelloMsg) marshal() []byte {
@@ -80,6 +82,10 @@ func (m *clientHelloMsg) marshal() []byte {
extensionsLength += 2 + 2*len(m.signatureAndHashes)
numExtensions++
}
+ if m.secureRenegotiation {
+ extensionsLength += 1
+ numExtensions++
+ }
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -114,13 +120,13 @@ func (m *clientHelloMsg) marshal() []byte {
}
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
- z[1] = byte(extensionNextProtoNeg)
+ z[1] = byte(extensionNextProtoNeg & 0xff)
// The length is always 0
z = z[4:]
}
if len(m.serverName) > 0 {
z[0] = byte(extensionServerName >> 8)
- z[1] = byte(extensionServerName)
+ z[1] = byte(extensionServerName & 0xff)
l := len(m.serverName) + 5
z[2] = byte(l >> 8)
z[3] = byte(l)
@@ -224,6 +230,13 @@ func (m *clientHelloMsg) marshal() []byte {
z = z[2:]
}
}
+ if m.secureRenegotiation {
+ z[0] = byte(extensionRenegotiationInfo >> 8)
+ z[1] = byte(extensionRenegotiationInfo & 0xff)
+ z[2] = 0
+ z[3] = 1
+ z = z[5:]
+ }
m.raw = x
@@ -256,6 +269,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.cipherSuites = make([]uint16, numCipherSuites)
for i := 0; i < numCipherSuites; i++ {
m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
+ if m.cipherSuites[i] == scsvRenegotiation {
+ m.secureRenegotiation = true
+ }
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
@@ -341,10 +357,10 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false
}
numCurves := l / 2
- m.supportedCurves = make([]uint16, numCurves)
+ m.supportedCurves = make([]CurveID, numCurves)
d := data[2:]
for i := 0; i < numCurves; i++ {
- m.supportedCurves[i] = uint16(d[0])<<8 | uint16(d[1])
+ m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
d = d[2:]
}
case extensionSupportedPoints:
@@ -379,6 +395,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.signatureAndHashes[i].signature = d[1]
d = d[2:]
}
+ case extensionRenegotiationInfo + 1:
+ if length != 1 || data[0] != 0 {
+ return false
+ }
+ m.secureRenegotiation = true
}
data = data[length:]
}
@@ -387,16 +408,17 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
}
type serverHelloMsg struct {
- raw []byte
- vers uint16
- random []byte
- sessionId []byte
- cipherSuite uint16
- compressionMethod uint8
- nextProtoNeg bool
- nextProtos []string
- ocspStapling bool
- ticketSupported bool
+ raw []byte
+ vers uint16
+ random []byte
+ sessionId []byte
+ cipherSuite uint16
+ compressionMethod uint8
+ nextProtoNeg bool
+ nextProtos []string
+ ocspStapling bool
+ ticketSupported bool
+ secureRenegotiation bool
}
func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -414,7 +436,8 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
m.nextProtoNeg == m1.nextProtoNeg &&
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling &&
- m.ticketSupported == m1.ticketSupported
+ m.ticketSupported == m1.ticketSupported &&
+ m.secureRenegotiation == m1.secureRenegotiation
}
func (m *serverHelloMsg) marshal() []byte {
@@ -441,6 +464,10 @@ func (m *serverHelloMsg) marshal() []byte {
if m.ticketSupported {
numExtensions++
}
+ if m.secureRenegotiation {
+ extensionsLength += 1
+ numExtensions++
+ }
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -469,7 +496,7 @@ func (m *serverHelloMsg) marshal() []byte {
}
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
- z[1] = byte(extensionNextProtoNeg)
+ z[1] = byte(extensionNextProtoNeg & 0xff)
z[2] = byte(nextProtoLen >> 8)
z[3] = byte(nextProtoLen)
z = z[4:]
@@ -494,6 +521,13 @@ func (m *serverHelloMsg) marshal() []byte {
z[1] = byte(extensionSessionTicket)
z = z[4:]
}
+ if m.secureRenegotiation {
+ z[0] = byte(extensionRenegotiationInfo >> 8)
+ z[1] = byte(extensionRenegotiationInfo & 0xff)
+ z[2] = 0
+ z[3] = 1
+ z = z[5:]
+ }
m.raw = x
@@ -573,6 +607,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return false
}
m.ticketSupported = true
+ case extensionRenegotiationInfo:
+ if length != 1 || data[0] != 0 {
+ return false
+ }
+ m.secureRenegotiation = true
}
data = data[length:]
}
@@ -1255,6 +1294,18 @@ func eqUint16s(x, y []uint16) bool {
return true
}
+func eqCurveIDs(x, y []CurveID) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
func eqStrings(x, y []string) bool {
if len(x) != len(y) {
return false