summaryrefslogtreecommitdiff
path: root/src/pkg/crypto/tls/handshake_messages.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/crypto/tls/handshake_messages.go')
-rw-r--r--src/pkg/crypto/tls/handshake_messages.go201
1 files changed, 182 insertions, 19 deletions
diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go
index f11232d8e..e1517cc79 100644
--- a/src/pkg/crypto/tls/handshake_messages.go
+++ b/src/pkg/crypto/tls/handshake_messages.go
@@ -4,6 +4,8 @@
package tls
+import "bytes"
+
type clientHelloMsg struct {
raw []byte
vers uint16
@@ -18,6 +20,25 @@ type clientHelloMsg struct {
supportedPoints []uint8
}
+func (m *clientHelloMsg) equal(i interface{}) bool {
+ m1, ok := i.(*clientHelloMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
+ bytes.Equal(m.random, m1.random) &&
+ bytes.Equal(m.sessionId, m1.sessionId) &&
+ eqUint16s(m.cipherSuites, m1.cipherSuites) &&
+ bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
+ m.nextProtoNeg == m1.nextProtoNeg &&
+ m.serverName == m1.serverName &&
+ m.ocspStapling == m1.ocspStapling &&
+ eqUint16s(m.supportedCurves, m1.supportedCurves) &&
+ bytes.Equal(m.supportedPoints, m1.supportedPoints)
+}
+
func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@@ -309,6 +330,23 @@ type serverHelloMsg struct {
ocspStapling bool
}
+func (m *serverHelloMsg) equal(i interface{}) bool {
+ m1, ok := i.(*serverHelloMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
+ bytes.Equal(m.random, m1.random) &&
+ bytes.Equal(m.sessionId, m1.sessionId) &&
+ m.cipherSuite == m1.cipherSuite &&
+ m.compressionMethod == m1.compressionMethod &&
+ m.nextProtoNeg == m1.nextProtoNeg &&
+ eqStrings(m.nextProtos, m1.nextProtos) &&
+ m.ocspStapling == m1.ocspStapling
+}
+
func (m *serverHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@@ -463,6 +501,16 @@ type certificateMsg struct {
certificates [][]byte
}
+func (m *certificateMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ eqByteSlices(m.certificates, m1.certificates)
+}
+
func (m *certificateMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
key []byte
}
+func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
+ m1, ok := i.(*serverKeyExchangeMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.key, m1.key)
+}
+
func (m *serverKeyExchangeMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@@ -571,6 +629,17 @@ type certificateStatusMsg struct {
response []byte
}
+func (m *certificateStatusMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateStatusMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.statusType == m1.statusType &&
+ bytes.Equal(m.response, m1.response)
+}
+
func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{}
+func (m *serverHelloDoneMsg) equal(i interface{}) bool {
+ _, ok := i.(*serverHelloDoneMsg)
+ return ok
+}
+
func (m *serverHelloDoneMsg) marshal() []byte {
x := make([]byte, 4)
x[0] = typeServerHelloDone
@@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
ciphertext []byte
}
+func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
+ m1, ok := i.(*clientKeyExchangeMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.ciphertext, m1.ciphertext)
+}
+
func (m *clientKeyExchangeMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@@ -671,6 +755,16 @@ type finishedMsg struct {
verifyData []byte
}
+func (m *finishedMsg) equal(i interface{}) bool {
+ m1, ok := i.(*finishedMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.verifyData, m1.verifyData)
+}
+
func (m *finishedMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@@ -698,6 +792,16 @@ type nextProtoMsg struct {
proto string
}
+func (m *nextProtoMsg) equal(i interface{}) bool {
+ m1, ok := i.(*nextProtoMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.proto == m1.proto
+}
+
func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@@ -759,6 +863,17 @@ type certificateRequestMsg struct {
certificateAuthorities [][]byte
}
+func (m *certificateRequestMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateRequestMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
+ eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
+}
+
func (m *certificateRequestMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@@ -766,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
// See http://tools.ietf.org/html/rfc4346#section-7.4.4
length := 1 + len(m.certificateTypes) + 2
+ casLength := 0
for _, ca := range m.certificateAuthorities {
- length += 2 + len(ca)
+ casLength += 2 + len(ca)
}
+ length += casLength
x = make([]byte, 4+length)
x[0] = typeCertificateRequest
@@ -780,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):]
-
- numCA := len(m.certificateAuthorities)
- y[0] = uint8(numCA >> 8)
- y[1] = uint8(numCA)
+ y[0] = uint8(casLength >> 8)
+ y[1] = uint8(casLength)
y = y[2:]
for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8)
@@ -794,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
}
m.raw = x
-
return
}
@@ -822,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
}
data = data[numCertTypes:]
+
if len(data) < 2 {
return false
}
-
- numCAs := uint16(data[0])<<16 | uint16(data[1])
+ casLength := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
+ if len(data) < int(casLength) {
+ return false
+ }
+ cas := make([]byte, casLength)
+ copy(cas, data)
+ data = data[casLength:]
- m.certificateAuthorities = make([][]byte, numCAs)
- for i := uint16(0); i < numCAs; i++ {
- if len(data) < 2 {
+ m.certificateAuthorities = nil
+ for len(cas) > 0 {
+ if len(cas) < 2 {
return false
}
- caLen := uint16(data[0])<<16 | uint16(data[1])
+ caLen := uint16(cas[0])<<8 | uint16(cas[1])
+ cas = cas[2:]
- data = data[2:]
- if len(data) < int(caLen) {
+ if len(cas) < int(caLen) {
return false
}
- ca := make([]byte, caLen)
- copy(ca, data)
- m.certificateAuthorities[i] = ca
- data = data[caLen:]
+ m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
+ cas = cas[caLen:]
}
-
if len(data) > 0 {
return false
}
@@ -859,6 +976,16 @@ type certificateVerifyMsg struct {
signature []byte
}
+func (m *certificateVerifyMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateVerifyMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.signature, m1.signature)
+}
+
func (m *certificateVerifyMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@@ -902,3 +1029,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
return true
}
+
+func eqUint16s(x, y []uint16) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqStrings(x, y []string) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqByteSlices(x, y [][]byte) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if !bytes.Equal(v, y[i]) {
+ return false
+ }
+ }
+ return true
+}