summaryrefslogtreecommitdiff
path: root/src/pkg/exp/ssh/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/exp/ssh/server.go')
-rw-r--r--src/pkg/exp/ssh/server.go714
1 files changed, 714 insertions, 0 deletions
diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go
new file mode 100644
index 000000000..bc0af13e8
--- /dev/null
+++ b/src/pkg/exp/ssh/server.go
@@ -0,0 +1,714 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "big"
+ "bufio"
+ "bytes"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ _ "crypto/sha1"
+ "crypto/x509"
+ "encoding/pem"
+ "net"
+ "os"
+ "sync"
+)
+
+var supportedKexAlgos = []string{kexAlgoDH14SHA1}
+var supportedHostKeyAlgos = []string{hostAlgoRSA}
+var supportedCiphers = []string{cipherAES128CTR}
+var supportedMACs = []string{macSHA196}
+var supportedCompressions = []string{compressionNone}
+
+// Server represents an SSH server. A Server may have several ServerConnections.
+type Server struct {
+ rsa *rsa.PrivateKey
+ rsaSerialized []byte
+
+ // NoClientAuth is true if clients are allowed to connect without
+ // authenticating.
+ NoClientAuth bool
+
+ // PasswordCallback, if non-nil, is called when a user attempts to
+ // authenticate using a password. It may be called concurrently from
+ // several goroutines.
+ PasswordCallback func(user, password string) bool
+
+ // PubKeyCallback, if non-nil, is called when a client attempts public
+ // key authentication. It must return true iff the given public key is
+ // valid for the given user.
+ PubKeyCallback func(user, algo string, pubkey []byte) bool
+}
+
+// SetRSAPrivateKey sets the private key for a Server. A Server must have a
+// private key configured in order to accept connections. The private key must
+// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
+// typically contains such a key.
+func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error {
+ block, _ := pem.Decode(pemBytes)
+ if block == nil {
+ return os.NewError("ssh: no key found")
+ }
+ var err os.Error
+ s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+
+ s.rsaSerialized = marshalRSA(s.rsa)
+ return nil
+}
+
+// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
+func marshalRSA(priv *rsa.PrivateKey) []byte {
+ e := new(big.Int).SetInt64(int64(priv.E))
+ length := stringLength([]byte(hostAlgoRSA))
+ length += intLength(e)
+ length += intLength(priv.N)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(hostAlgoRSA))
+ r = marshalInt(r, e)
+ r = marshalInt(r, priv.N)
+
+ return ret
+}
+
+// parseRSA parses an RSA key according to RFC 4256, section 6.6.
+func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
+ algo, in, ok := parseString(in)
+ if !ok || string(algo) != hostAlgoRSA {
+ return nil, false
+ }
+ bigE, in, ok := parseInt(in)
+ if !ok || bigE.BitLen() > 24 {
+ return nil, false
+ }
+ e := bigE.Int64()
+ if e < 3 || e&1 == 0 {
+ return nil, false
+ }
+ N, in, ok := parseInt(in)
+ if !ok || len(in) > 0 {
+ return nil, false
+ }
+ return &rsa.PublicKey{
+ N: N,
+ E: int(e),
+ }, true
+}
+
+func parseRSASig(in []byte) (sig []byte, ok bool) {
+ algo, in, ok := parseString(in)
+ if !ok || string(algo) != hostAlgoRSA {
+ return nil, false
+ }
+ sig, in, ok = parseString(in)
+ if len(in) > 0 {
+ ok = false
+ }
+ return
+}
+
+// cachedPubKey contains the results of querying whether a public key is
+// acceptable for a user. The cache only applies to a single ServerConnection.
+type cachedPubKey struct {
+ user, algo string
+ pubKey []byte
+ result bool
+}
+
+const maxCachedPubKeys = 16
+
+// ServerConnection represents an incomming connection to a Server.
+type ServerConnection struct {
+ Server *Server
+
+ *transport
+
+ channels map[uint32]*channel
+ nextChanId uint32
+
+ // lock protects err and also allows Channels to serialise their writes
+ // to out.
+ lock sync.RWMutex
+ err os.Error
+
+ // cachedPubKeys contains the cache results of tests for public keys.
+ // Since SSH clients will query whether a public key is acceptable
+ // before attempting to authenticate with it, we end up with duplicate
+ // queries for public key validity.
+ cachedPubKeys []cachedPubKey
+}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+ g, p *big.Int
+}
+
+// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
+// Oakley Group 14 in RFC 3526.
+var dhGroup14 *dhGroup
+
+var dhGroup14Once sync.Once
+
+func initDHGroup14() {
+ p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+ dhGroup14 = &dhGroup{
+ g: new(big.Int).SetInt64(2),
+ p: p,
+ }
+}
+
+type handshakeMagics struct {
+ clientVersion, serverVersion []byte
+ clientKexInit, serverKexInit []byte
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
+// returned values are given the same names as in RFC 4253, section 8.
+func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
+ packet, err := s.readPacket()
+ if err != nil {
+ return
+ }
+ var kexDHInit kexDHInitMsg
+ if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+ return
+ }
+
+ if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
+ return nil, nil, os.NewError("client DH parameter out of bounds")
+ }
+
+ y, err := rand.Int(rand.Reader, group.p)
+ if err != nil {
+ return
+ }
+
+ Y := new(big.Int).Exp(group.g, y, group.p)
+ kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
+
+ var serializedHostKey []byte
+ switch hostKeyAlgo {
+ case hostAlgoRSA:
+ serializedHostKey = s.Server.rsaSerialized
+ default:
+ return nil, nil, os.NewError("internal error")
+ }
+
+ h := hashFunc.New()
+ writeString(h, magics.clientVersion)
+ writeString(h, magics.serverVersion)
+ writeString(h, magics.clientKexInit)
+ writeString(h, magics.serverKexInit)
+ writeString(h, serializedHostKey)
+ writeInt(h, kexDHInit.X)
+ writeInt(h, Y)
+ K = make([]byte, intLength(kInt))
+ marshalInt(K, kInt)
+ h.Write(K)
+
+ H = h.Sum()
+
+ h.Reset()
+ h.Write(H)
+ hh := h.Sum()
+
+ var sig []byte
+ switch hostKeyAlgo {
+ case hostAlgoRSA:
+ sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh)
+ if err != nil {
+ return
+ }
+ default:
+ return nil, nil, os.NewError("internal error")
+ }
+
+ serializedSig := serializeRSASignature(sig)
+
+ kexDHReply := kexDHReplyMsg{
+ HostKey: serializedHostKey,
+ Y: Y,
+ Signature: serializedSig,
+ }
+ packet = marshal(msgKexDHReply, kexDHReply)
+
+ err = s.writePacket(packet)
+ return
+}
+
+func serializeRSASignature(sig []byte) []byte {
+ length := stringLength([]byte(hostAlgoRSA))
+ length += stringLength(sig)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(hostAlgoRSA))
+ r = marshalString(r, sig)
+
+ return ret
+}
+
+// serverVersion is the fixed identification string that Server will use.
+var serverVersion = []byte("SSH-2.0-Go\r\n")
+
+// buildDataSignedForAuth returns the data that is signed in order to prove
+// posession of a private key. See RFC 4252, section 7.
+func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+ user := []byte(req.User)
+ service := []byte(req.Service)
+ method := []byte(req.Method)
+
+ length := stringLength(sessionId)
+ length += 1
+ length += stringLength(user)
+ length += stringLength(service)
+ length += stringLength(method)
+ length += 1
+ length += stringLength(algo)
+ length += stringLength(pubKey)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, sessionId)
+ r[0] = msgUserAuthRequest
+ r = r[1:]
+ r = marshalString(r, user)
+ r = marshalString(r, service)
+ r = marshalString(r, method)
+ r[0] = 1
+ r = r[1:]
+ r = marshalString(r, algo)
+ r = marshalString(r, pubKey)
+ return ret
+}
+
+// Handshake performs an SSH transport and client authentication on the given ServerConnection.
+func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
+ var magics handshakeMagics
+ s.transport = &transport{
+ reader: reader{
+ Reader: bufio.NewReader(conn),
+ },
+ writer: writer{
+ Writer: bufio.NewWriter(conn),
+ rand: rand.Reader,
+ },
+ Close: func() os.Error {
+ return conn.Close()
+ },
+ }
+
+ if _, err := conn.Write(serverVersion); err != nil {
+ return err
+ }
+ magics.serverVersion = serverVersion[:len(serverVersion)-2]
+
+ version, ok := readVersion(s.transport)
+ if !ok {
+ return os.NewError("failed to read version string from client")
+ }
+ magics.clientVersion = version
+
+ serverKexInit := kexInitMsg{
+ KexAlgos: supportedKexAlgos,
+ ServerHostKeyAlgos: supportedHostKeyAlgos,
+ CiphersClientServer: supportedCiphers,
+ CiphersServerClient: supportedCiphers,
+ MACsClientServer: supportedMACs,
+ MACsServerClient: supportedMACs,
+ CompressionClientServer: supportedCompressions,
+ CompressionServerClient: supportedCompressions,
+ }
+ kexInitPacket := marshal(msgKexInit, serverKexInit)
+ magics.serverKexInit = kexInitPacket
+
+ if err := s.writePacket(kexInitPacket); err != nil {
+ return err
+ }
+
+ packet, err := s.readPacket()
+ if err != nil {
+ return err
+ }
+
+ magics.clientKexInit = packet
+
+ var clientKexInit kexInitMsg
+ if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
+ return err
+ }
+
+ kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, s.transport, &clientKexInit, &serverKexInit)
+ if !ok {
+ return os.NewError("ssh: no common algorithms")
+ }
+
+ if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
+ // The client sent a Kex message for the wrong algorithm,
+ // which we have to ignore.
+ _, err := s.readPacket()
+ if err != nil {
+ return err
+ }
+ }
+
+ var H, K []byte
+ var hashFunc crypto.Hash
+ switch kexAlgo {
+ case kexAlgoDH14SHA1:
+ hashFunc = crypto.SHA1
+ dhGroup14Once.Do(initDHGroup14)
+ H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
+ default:
+ err = os.NewError("ssh: internal error")
+ }
+
+ if err != nil {
+ return err
+ }
+
+ packet = []byte{msgNewKeys}
+ if err = s.writePacket(packet); err != nil {
+ return err
+ }
+ if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+ return err
+ }
+
+ if packet, err = s.readPacket(); err != nil {
+ return err
+ }
+ if packet[0] != msgNewKeys {
+ return UnexpectedMessageError{msgNewKeys, packet[0]}
+ }
+
+ s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
+
+ packet, err = s.readPacket()
+ if err != nil {
+ return err
+ }
+
+ var serviceRequest serviceRequestMsg
+ if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
+ return err
+ }
+ if serviceRequest.Service != serviceUserAuth {
+ return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+ }
+
+ serviceAccept := serviceAcceptMsg{
+ Service: serviceUserAuth,
+ }
+ packet = marshal(msgServiceAccept, serviceAccept)
+ if err = s.writePacket(packet); err != nil {
+ return err
+ }
+
+ if err = s.authenticate(H); err != nil {
+ return err
+ }
+
+ s.channels = make(map[uint32]*channel)
+ return nil
+}
+
+func isAcceptableAlgo(algo string) bool {
+ return algo == hostAlgoRSA
+}
+
+// testPubKey returns true if the given public key is acceptable for the user.
+func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
+ if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
+ return false
+ }
+
+ for _, c := range s.cachedPubKeys {
+ if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
+ return c.result
+ }
+ }
+
+ result := s.Server.PubKeyCallback(user, algo, pubKey)
+ if len(s.cachedPubKeys) < maxCachedPubKeys {
+ c := cachedPubKey{
+ user: user,
+ algo: algo,
+ pubKey: make([]byte, len(pubKey)),
+ result: result,
+ }
+ copy(c.pubKey, pubKey)
+ s.cachedPubKeys = append(s.cachedPubKeys, c)
+ }
+
+ return result
+}
+
+func (s *ServerConnection) authenticate(H []byte) os.Error {
+ var userAuthReq userAuthRequestMsg
+ var err os.Error
+ var packet []byte
+
+userAuthLoop:
+ for {
+ if packet, err = s.readPacket(); err != nil {
+ return err
+ }
+ if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
+ return err
+ }
+
+ if userAuthReq.Service != serviceSSH {
+ return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
+ }
+
+ switch userAuthReq.Method {
+ case "none":
+ if s.Server.NoClientAuth {
+ break userAuthLoop
+ }
+ case "password":
+ if s.Server.PasswordCallback == nil {
+ break
+ }
+ payload := userAuthReq.Payload
+ if len(payload) < 1 || payload[0] != 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+ payload = payload[1:]
+ password, payload, ok := parseString(payload)
+ if !ok || len(payload) > 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+
+ if s.Server.PasswordCallback(userAuthReq.User, string(password)) {
+ break userAuthLoop
+ }
+ case "publickey":
+ if s.Server.PubKeyCallback == nil {
+ break
+ }
+ payload := userAuthReq.Payload
+ if len(payload) < 1 {
+ return ParseError{msgUserAuthRequest}
+ }
+ isQuery := payload[0] == 0
+ payload = payload[1:]
+ algoBytes, payload, ok := parseString(payload)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ algo := string(algoBytes)
+
+ pubKey, payload, ok := parseString(payload)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ if isQuery {
+ // The client can query if the given public key
+ // would be ok.
+ if len(payload) > 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+ if s.testPubKey(userAuthReq.User, algo, pubKey) {
+ okMsg := userAuthPubKeyOkMsg{
+ Algo: algo,
+ PubKey: string(pubKey),
+ }
+ if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+ return err
+ }
+ continue userAuthLoop
+ }
+ } else {
+ sig, payload, ok := parseString(payload)
+ if !ok || len(payload) > 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+ if !isAcceptableAlgo(algo) {
+ break
+ }
+ rsaSig, ok := parseRSASig(sig)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey)
+ switch algo {
+ case hostAlgoRSA:
+ hashFunc := crypto.SHA1
+ h := hashFunc.New()
+ h.Write(signedData)
+ digest := h.Sum()
+ rsaKey, ok := parseRSA(pubKey)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil {
+ return ParseError{msgUserAuthRequest}
+ }
+ default:
+ return os.NewError("ssh: isAcceptableAlgo incorrect")
+ }
+ if s.testPubKey(userAuthReq.User, algo, pubKey) {
+ break userAuthLoop
+ }
+ }
+ }
+
+ var failureMsg userAuthFailureMsg
+ if s.Server.PasswordCallback != nil {
+ failureMsg.Methods = append(failureMsg.Methods, "password")
+ }
+ if s.Server.PubKeyCallback != nil {
+ failureMsg.Methods = append(failureMsg.Methods, "publickey")
+ }
+
+ if len(failureMsg.Methods) == 0 {
+ return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false")
+ }
+
+ if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+ return err
+ }
+ }
+
+ packet = []byte{msgUserAuthSuccess}
+ if err = s.writePacket(packet); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+const defaultWindowSize = 32768
+
+// Accept reads and processes messages on a ServerConnection. It must be called
+// in order to demultiplex messages to any resulting Channels.
+func (s *ServerConnection) Accept() (Channel, os.Error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+
+ for {
+ packet, err := s.readPacket()
+ if err != nil {
+
+ s.lock.Lock()
+ s.err = err
+ s.lock.Unlock()
+
+ for _, c := range s.channels {
+ c.dead = true
+ c.handleData(nil)
+ }
+
+ return nil, err
+ }
+
+ switch packet[0] {
+ case msgChannelOpen:
+ var chanOpen channelOpenMsg
+ if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
+ return nil, err
+ }
+
+ c := new(channel)
+ c.chanType = chanOpen.ChanType
+ c.theirId = chanOpen.PeersId
+ c.theirWindow = chanOpen.PeersWindow
+ c.maxPacketSize = chanOpen.MaxPacketSize
+ c.extraData = chanOpen.TypeSpecificData
+ c.myWindow = defaultWindowSize
+ c.serverConn = s
+ c.cond = sync.NewCond(&c.lock)
+ c.pendingData = make([]byte, c.myWindow)
+
+ s.lock.Lock()
+ c.myId = s.nextChanId
+ s.nextChanId++
+ s.channels[c.myId] = c
+ s.lock.Unlock()
+ return c, nil
+
+ case msgChannelRequest:
+ var chanRequest channelRequestMsg
+ if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
+ return nil, err
+ }
+
+ s.lock.Lock()
+ c, ok := s.channels[chanRequest.PeersId]
+ if !ok {
+ continue
+ }
+ c.handlePacket(&chanRequest)
+ s.lock.Unlock()
+
+ case msgChannelData:
+ if len(packet) < 5 {
+ return nil, ParseError{msgChannelData}
+ }
+ chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+
+ s.lock.Lock()
+ c, ok := s.channels[chanId]
+ if !ok {
+ continue
+ }
+ c.handleData(packet[9:])
+ s.lock.Unlock()
+
+ case msgChannelEOF:
+ var eofMsg channelEOFMsg
+ if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
+ return nil, err
+ }
+
+ s.lock.Lock()
+ c, ok := s.channels[eofMsg.PeersId]
+ if !ok {
+ continue
+ }
+ c.handlePacket(&eofMsg)
+ s.lock.Unlock()
+
+ case msgChannelClose:
+ var closeMsg channelCloseMsg
+ if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
+ return nil, err
+ }
+
+ s.lock.Lock()
+ c, ok := s.channels[closeMsg.PeersId]
+ if !ok {
+ continue
+ }
+ c.handlePacket(&closeMsg)
+ s.lock.Unlock()
+
+ case msgGlobalRequest:
+ var request globalRequestMsg
+ if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
+ return nil, err
+ }
+
+ if request.WantReply {
+ if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+ return nil, err
+ }
+ }
+
+ default:
+ // Unknown message. Ignore.
+ }
+ }
+
+ panic("unreachable")
+}