diff options
Diffstat (limited to 'src/pkg/exp/ssh/server.go')
-rw-r--r-- | src/pkg/exp/ssh/server.go | 714 |
1 files changed, 714 insertions, 0 deletions
diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go new file mode 100644 index 000000000..bc0af13e8 --- /dev/null +++ b/src/pkg/exp/ssh/server.go @@ -0,0 +1,714 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "big" + "bufio" + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + _ "crypto/sha1" + "crypto/x509" + "encoding/pem" + "net" + "os" + "sync" +) + +var supportedKexAlgos = []string{kexAlgoDH14SHA1} +var supportedHostKeyAlgos = []string{hostAlgoRSA} +var supportedCiphers = []string{cipherAES128CTR} +var supportedMACs = []string{macSHA196} +var supportedCompressions = []string{compressionNone} + +// Server represents an SSH server. A Server may have several ServerConnections. +type Server struct { + rsa *rsa.PrivateKey + rsaSerialized []byte + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + NoClientAuth bool + + // PasswordCallback, if non-nil, is called when a user attempts to + // authenticate using a password. It may be called concurrently from + // several goroutines. + PasswordCallback func(user, password string) bool + + // PubKeyCallback, if non-nil, is called when a client attempts public + // key authentication. It must return true iff the given public key is + // valid for the given user. + PubKeyCallback func(user, algo string, pubkey []byte) bool +} + +// SetRSAPrivateKey sets the private key for a Server. A Server must have a +// private key configured in order to accept connections. The private key must +// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" +// typically contains such a key. +func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error { + block, _ := pem.Decode(pemBytes) + if block == nil { + return os.NewError("ssh: no key found") + } + var err os.Error + s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return err + } + + s.rsaSerialized = marshalRSA(s.rsa) + return nil +} + +// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6. +func marshalRSA(priv *rsa.PrivateKey) []byte { + e := new(big.Int).SetInt64(int64(priv.E)) + length := stringLength([]byte(hostAlgoRSA)) + length += intLength(e) + length += intLength(priv.N) + + ret := make([]byte, length) + r := marshalString(ret, []byte(hostAlgoRSA)) + r = marshalInt(r, e) + r = marshalInt(r, priv.N) + + return ret +} + +// parseRSA parses an RSA key according to RFC 4256, section 6.6. +func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) { + algo, in, ok := parseString(in) + if !ok || string(algo) != hostAlgoRSA { + return nil, false + } + bigE, in, ok := parseInt(in) + if !ok || bigE.BitLen() > 24 { + return nil, false + } + e := bigE.Int64() + if e < 3 || e&1 == 0 { + return nil, false + } + N, in, ok := parseInt(in) + if !ok || len(in) > 0 { + return nil, false + } + return &rsa.PublicKey{ + N: N, + E: int(e), + }, true +} + +func parseRSASig(in []byte) (sig []byte, ok bool) { + algo, in, ok := parseString(in) + if !ok || string(algo) != hostAlgoRSA { + return nil, false + } + sig, in, ok = parseString(in) + if len(in) > 0 { + ok = false + } + return +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. The cache only applies to a single ServerConnection. +type cachedPubKey struct { + user, algo string + pubKey []byte + result bool +} + +const maxCachedPubKeys = 16 + +// ServerConnection represents an incomming connection to a Server. +type ServerConnection struct { + Server *Server + + *transport + + channels map[uint32]*channel + nextChanId uint32 + + // lock protects err and also allows Channels to serialise their writes + // to out. + lock sync.RWMutex + err os.Error + + // cachedPubKeys contains the cache results of tests for public keys. + // Since SSH clients will query whether a public key is acceptable + // before attempting to authenticate with it, we end up with duplicate + // queries for public key validity. + cachedPubKeys []cachedPubKey +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p *big.Int +} + +// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and +// Oakley Group 14 in RFC 3526. +var dhGroup14 *dhGroup + +var dhGroup14Once sync.Once + +func initDHGroup14() { + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + + dhGroup14 = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + } +} + +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The +// returned values are given the same names as in RFC 4253, section 8. +func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { + packet, err := s.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil { + return + } + + if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 { + return nil, nil, os.NewError("client DH parameter out of bounds") + } + + y, err := rand.Int(rand.Reader, group.p) + if err != nil { + return + } + + Y := new(big.Int).Exp(group.g, y, group.p) + kInt := new(big.Int).Exp(kexDHInit.X, y, group.p) + + var serializedHostKey []byte + switch hostKeyAlgo { + case hostAlgoRSA: + serializedHostKey = s.Server.rsaSerialized + default: + return nil, nil, os.NewError("internal error") + } + + h := hashFunc.New() + writeString(h, magics.clientVersion) + writeString(h, magics.serverVersion) + writeString(h, magics.clientKexInit) + writeString(h, magics.serverKexInit) + writeString(h, serializedHostKey) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + K = make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H = h.Sum() + + h.Reset() + h.Write(H) + hh := h.Sum() + + var sig []byte + switch hostKeyAlgo { + case hostAlgoRSA: + sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh) + if err != nil { + return + } + default: + return nil, nil, os.NewError("internal error") + } + + serializedSig := serializeRSASignature(sig) + + kexDHReply := kexDHReplyMsg{ + HostKey: serializedHostKey, + Y: Y, + Signature: serializedSig, + } + packet = marshal(msgKexDHReply, kexDHReply) + + err = s.writePacket(packet) + return +} + +func serializeRSASignature(sig []byte) []byte { + length := stringLength([]byte(hostAlgoRSA)) + length += stringLength(sig) + + ret := make([]byte, length) + r := marshalString(ret, []byte(hostAlgoRSA)) + r = marshalString(r, sig) + + return ret +} + +// serverVersion is the fixed identification string that Server will use. +var serverVersion = []byte("SSH-2.0-Go\r\n") + +// buildDataSignedForAuth returns the data that is signed in order to prove +// posession of a private key. See RFC 4252, section 7. +func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { + user := []byte(req.User) + service := []byte(req.Service) + method := []byte(req.Method) + + length := stringLength(sessionId) + length += 1 + length += stringLength(user) + length += stringLength(service) + length += stringLength(method) + length += 1 + length += stringLength(algo) + length += stringLength(pubKey) + + ret := make([]byte, length) + r := marshalString(ret, sessionId) + r[0] = msgUserAuthRequest + r = r[1:] + r = marshalString(r, user) + r = marshalString(r, service) + r = marshalString(r, method) + r[0] = 1 + r = r[1:] + r = marshalString(r, algo) + r = marshalString(r, pubKey) + return ret +} + +// Handshake performs an SSH transport and client authentication on the given ServerConnection. +func (s *ServerConnection) Handshake(conn net.Conn) os.Error { + var magics handshakeMagics + s.transport = &transport{ + reader: reader{ + Reader: bufio.NewReader(conn), + }, + writer: writer{ + Writer: bufio.NewWriter(conn), + rand: rand.Reader, + }, + Close: func() os.Error { + return conn.Close() + }, + } + + if _, err := conn.Write(serverVersion); err != nil { + return err + } + magics.serverVersion = serverVersion[:len(serverVersion)-2] + + version, ok := readVersion(s.transport) + if !ok { + return os.NewError("failed to read version string from client") + } + magics.clientVersion = version + + serverKexInit := kexInitMsg{ + KexAlgos: supportedKexAlgos, + ServerHostKeyAlgos: supportedHostKeyAlgos, + CiphersClientServer: supportedCiphers, + CiphersServerClient: supportedCiphers, + MACsClientServer: supportedMACs, + MACsServerClient: supportedMACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + kexInitPacket := marshal(msgKexInit, serverKexInit) + magics.serverKexInit = kexInitPacket + + if err := s.writePacket(kexInitPacket); err != nil { + return err + } + + packet, err := s.readPacket() + if err != nil { + return err + } + + magics.clientKexInit = packet + + var clientKexInit kexInitMsg + if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil { + return err + } + + kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, s.transport, &clientKexInit, &serverKexInit) + if !ok { + return os.NewError("ssh: no common algorithms") + } + + if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { + // The client sent a Kex message for the wrong algorithm, + // which we have to ignore. + _, err := s.readPacket() + if err != nil { + return err + } + } + + var H, K []byte + var hashFunc crypto.Hash + switch kexAlgo { + case kexAlgoDH14SHA1: + hashFunc = crypto.SHA1 + dhGroup14Once.Do(initDHGroup14) + H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) + default: + err = os.NewError("ssh: internal error") + } + + if err != nil { + return err + } + + packet = []byte{msgNewKeys} + if err = s.writePacket(packet); err != nil { + return err + } + if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { + return err + } + + if packet, err = s.readPacket(); err != nil { + return err + } + if packet[0] != msgNewKeys { + return UnexpectedMessageError{msgNewKeys, packet[0]} + } + + s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) + + packet, err = s.readPacket() + if err != nil { + return err + } + + var serviceRequest serviceRequestMsg + if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil { + return err + } + if serviceRequest.Service != serviceUserAuth { + return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + packet = marshal(msgServiceAccept, serviceAccept) + if err = s.writePacket(packet); err != nil { + return err + } + + if err = s.authenticate(H); err != nil { + return err + } + + s.channels = make(map[uint32]*channel) + return nil +} + +func isAcceptableAlgo(algo string) bool { + return algo == hostAlgoRSA +} + +// testPubKey returns true if the given public key is acceptable for the user. +func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { + if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) { + return false + } + + for _, c := range s.cachedPubKeys { + if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) { + return c.result + } + } + + result := s.Server.PubKeyCallback(user, algo, pubKey) + if len(s.cachedPubKeys) < maxCachedPubKeys { + c := cachedPubKey{ + user: user, + algo: algo, + pubKey: make([]byte, len(pubKey)), + result: result, + } + copy(c.pubKey, pubKey) + s.cachedPubKeys = append(s.cachedPubKeys, c) + } + + return result +} + +func (s *ServerConnection) authenticate(H []byte) os.Error { + var userAuthReq userAuthRequestMsg + var err os.Error + var packet []byte + +userAuthLoop: + for { + if packet, err = s.readPacket(); err != nil { + return err + } + if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil { + return err + } + + if userAuthReq.Service != serviceSSH { + return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + switch userAuthReq.Method { + case "none": + if s.Server.NoClientAuth { + break userAuthLoop + } + case "password": + if s.Server.PasswordCallback == nil { + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return ParseError{msgUserAuthRequest} + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return ParseError{msgUserAuthRequest} + } + + if s.Server.PasswordCallback(userAuthReq.User, string(password)) { + break userAuthLoop + } + case "publickey": + if s.Server.PubKeyCallback == nil { + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return ParseError{msgUserAuthRequest} + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return ParseError{msgUserAuthRequest} + } + algo := string(algoBytes) + + pubKey, payload, ok := parseString(payload) + if !ok { + return ParseError{msgUserAuthRequest} + } + if isQuery { + // The client can query if the given public key + // would be ok. + if len(payload) > 0 { + return ParseError{msgUserAuthRequest} + } + if s.testPubKey(userAuthReq.User, algo, pubKey) { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: string(pubKey), + } + if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil { + return err + } + continue userAuthLoop + } + } else { + sig, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return ParseError{msgUserAuthRequest} + } + if !isAcceptableAlgo(algo) { + break + } + rsaSig, ok := parseRSASig(sig) + if !ok { + return ParseError{msgUserAuthRequest} + } + signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey) + switch algo { + case hostAlgoRSA: + hashFunc := crypto.SHA1 + h := hashFunc.New() + h.Write(signedData) + digest := h.Sum() + rsaKey, ok := parseRSA(pubKey) + if !ok { + return ParseError{msgUserAuthRequest} + } + if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil { + return ParseError{msgUserAuthRequest} + } + default: + return os.NewError("ssh: isAcceptableAlgo incorrect") + } + if s.testPubKey(userAuthReq.User, algo, pubKey) { + break userAuthLoop + } + } + } + + var failureMsg userAuthFailureMsg + if s.Server.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if s.Server.PubKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + + if len(failureMsg.Methods) == 0 { + return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { + return err + } + } + + packet = []byte{msgUserAuthSuccess} + if err = s.writePacket(packet); err != nil { + return err + } + + return nil +} + +const defaultWindowSize = 32768 + +// Accept reads and processes messages on a ServerConnection. It must be called +// in order to demultiplex messages to any resulting Channels. +func (s *ServerConnection) Accept() (Channel, os.Error) { + if s.err != nil { + return nil, s.err + } + + for { + packet, err := s.readPacket() + if err != nil { + + s.lock.Lock() + s.err = err + s.lock.Unlock() + + for _, c := range s.channels { + c.dead = true + c.handleData(nil) + } + + return nil, err + } + + switch packet[0] { + case msgChannelOpen: + var chanOpen channelOpenMsg + if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil { + return nil, err + } + + c := new(channel) + c.chanType = chanOpen.ChanType + c.theirId = chanOpen.PeersId + c.theirWindow = chanOpen.PeersWindow + c.maxPacketSize = chanOpen.MaxPacketSize + c.extraData = chanOpen.TypeSpecificData + c.myWindow = defaultWindowSize + c.serverConn = s + c.cond = sync.NewCond(&c.lock) + c.pendingData = make([]byte, c.myWindow) + + s.lock.Lock() + c.myId = s.nextChanId + s.nextChanId++ + s.channels[c.myId] = c + s.lock.Unlock() + return c, nil + + case msgChannelRequest: + var chanRequest channelRequestMsg + if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil { + return nil, err + } + + s.lock.Lock() + c, ok := s.channels[chanRequest.PeersId] + if !ok { + continue + } + c.handlePacket(&chanRequest) + s.lock.Unlock() + + case msgChannelData: + if len(packet) < 5 { + return nil, ParseError{msgChannelData} + } + chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) + + s.lock.Lock() + c, ok := s.channels[chanId] + if !ok { + continue + } + c.handleData(packet[9:]) + s.lock.Unlock() + + case msgChannelEOF: + var eofMsg channelEOFMsg + if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil { + return nil, err + } + + s.lock.Lock() + c, ok := s.channels[eofMsg.PeersId] + if !ok { + continue + } + c.handlePacket(&eofMsg) + s.lock.Unlock() + + case msgChannelClose: + var closeMsg channelCloseMsg + if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil { + return nil, err + } + + s.lock.Lock() + c, ok := s.channels[closeMsg.PeersId] + if !ok { + continue + } + c.handlePacket(&closeMsg) + s.lock.Unlock() + + case msgGlobalRequest: + var request globalRequestMsg + if err := unmarshal(&request, packet, msgGlobalRequest); err != nil { + return nil, err + } + + if request.WantReply { + if err := s.writePacket([]byte{msgRequestFailure}); err != nil { + return nil, err + } + } + + default: + // Unknown message. Ignore. + } + } + + panic("unreachable") +} |