diff options
Diffstat (limited to 'src/pkg/crypto/tls/handshake_messages.go')
-rw-r--r-- | src/pkg/crypto/tls/handshake_messages.go | 201 |
1 files changed, 182 insertions, 19 deletions
diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go index f11232d8e..e1517cc79 100644 --- a/src/pkg/crypto/tls/handshake_messages.go +++ b/src/pkg/crypto/tls/handshake_messages.go @@ -4,6 +4,8 @@ package tls +import "bytes" + type clientHelloMsg struct { raw []byte vers uint16 @@ -18,6 +20,25 @@ type clientHelloMsg struct { supportedPoints []uint8 } +func (m *clientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*clientHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.compressionMethods, m1.compressionMethods) && + m.nextProtoNeg == m1.nextProtoNeg && + m.serverName == m1.serverName && + m.ocspStapling == m1.ocspStapling && + eqUint16s(m.supportedCurves, m1.supportedCurves) && + bytes.Equal(m.supportedPoints, m1.supportedPoints) +} + func (m *clientHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -309,6 +330,23 @@ type serverHelloMsg struct { ocspStapling bool } +func (m *serverHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*serverHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + m.cipherSuite == m1.cipherSuite && + m.compressionMethod == m1.compressionMethod && + m.nextProtoNeg == m1.nextProtoNeg && + eqStrings(m.nextProtos, m1.nextProtos) && + m.ocspStapling == m1.ocspStapling +} + func (m *serverHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -463,6 +501,16 @@ type certificateMsg struct { certificates [][]byte } +func (m *certificateMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + eqByteSlices(m.certificates, m1.certificates) +} + func (m *certificateMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct { key []byte } +func (m *serverKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*serverKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.key, m1.key) +} + func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -571,6 +629,17 @@ type certificateStatusMsg struct { response []byte } +func (m *certificateStatusMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateStatusMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.statusType == m1.statusType && + bytes.Equal(m.response, m1.response) +} + func (m *certificateStatusMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} +func (m *serverHelloDoneMsg) equal(i interface{}) bool { + _, ok := i.(*serverHelloDoneMsg) + return ok +} + func (m *serverHelloDoneMsg) marshal() []byte { x := make([]byte, 4) x[0] = typeServerHelloDone @@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct { ciphertext []byte } +func (m *clientKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*clientKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ciphertext, m1.ciphertext) +} + func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -671,6 +755,16 @@ type finishedMsg struct { verifyData []byte } +func (m *finishedMsg) equal(i interface{}) bool { + m1, ok := i.(*finishedMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.verifyData, m1.verifyData) +} + func (m *finishedMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -698,6 +792,16 @@ type nextProtoMsg struct { proto string } +func (m *nextProtoMsg) equal(i interface{}) bool { + m1, ok := i.(*nextProtoMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.proto == m1.proto +} + func (m *nextProtoMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -759,6 +863,17 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } +func (m *certificateRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.certificateTypes, m1.certificateTypes) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) +} + func (m *certificateRequestMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -766,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) { // See http://tools.ietf.org/html/rfc4346#section-7.4.4 length := 1 + len(m.certificateTypes) + 2 + casLength := 0 for _, ca := range m.certificateAuthorities { - length += 2 + len(ca) + casLength += 2 + len(ca) } + length += casLength x = make([]byte, 4+length) x[0] = typeCertificateRequest @@ -780,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) { copy(x[5:], m.certificateTypes) y := x[5+len(m.certificateTypes):] - - numCA := len(m.certificateAuthorities) - y[0] = uint8(numCA >> 8) - y[1] = uint8(numCA) + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) y = y[2:] for _, ca := range m.certificateAuthorities { y[0] = uint8(len(ca) >> 8) @@ -794,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) { } m.raw = x - return } @@ -822,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { } data = data[numCertTypes:] + if len(data) < 2 { return false } - - numCAs := uint16(data[0])<<16 | uint16(data[1]) + casLength := uint16(data[0])<<8 | uint16(data[1]) data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] - m.certificateAuthorities = make([][]byte, numCAs) - for i := uint16(0); i < numCAs; i++ { - if len(data) < 2 { + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { return false } - caLen := uint16(data[0])<<16 | uint16(data[1]) + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] - data = data[2:] - if len(data) < int(caLen) { + if len(cas) < int(caLen) { return false } - ca := make([]byte, caLen) - copy(ca, data) - m.certificateAuthorities[i] = ca - data = data[caLen:] + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] } - if len(data) > 0 { return false } @@ -859,6 +976,16 @@ type certificateVerifyMsg struct { signature []byte } +func (m *certificateVerifyMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateVerifyMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.signature, m1.signature) +} + func (m *certificateVerifyMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -902,3 +1029,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { return true } + +func eqUint16s(x, y []uint16) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqStrings(x, y []string) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqByteSlices(x, y [][]byte) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if !bytes.Equal(v, y[i]) { + return false + } + } + return true +} |