diff options
author | Adam Langley <agl@golang.org> | 2009-10-19 11:52:44 -0700 |
---|---|---|
committer | Adam Langley <agl@golang.org> | 2009-10-19 11:52:44 -0700 |
commit | a17a6f7bbd15f66146bff61e845bf1f707b81470 (patch) | |
tree | b50dd67613da6ea6ca40005c6ce1c10ac915c482 /src/pkg/crypto/rsa/rsa.go | |
parent | 44a290f13fb2efa7fbc09ef19a1cf468224080ab (diff) | |
download | golang-a17a6f7bbd15f66146bff61e845bf1f707b81470.tar.gz |
Add an RSA-OAEP implementation.
R=rsc
APPROVED=rsc
DELTA=734 (734 added, 0 deleted, 0 changed)
OCL=35738
CL=35879
Diffstat (limited to 'src/pkg/crypto/rsa/rsa.go')
-rw-r--r-- | src/pkg/crypto/rsa/rsa.go | 413 |
1 files changed, 413 insertions, 0 deletions
diff --git a/src/pkg/crypto/rsa/rsa.go b/src/pkg/crypto/rsa/rsa.go new file mode 100644 index 000000000..de98d3074 --- /dev/null +++ b/src/pkg/crypto/rsa/rsa.go @@ -0,0 +1,413 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package implements RSA encryption as specified in PKCS#1. +package rsa + +// TODO(agl): Add support for PSS padding. + +import ( + "bytes"; + big "gmp"; + "hash"; + "io"; + "os"; +) + +var bigOne = big.NewInt(1) + +// randomSafePrime returns a number, p, of the given size, such that p and +// (p-1)/2 are both prime with high probability. +func randomSafePrime(rand io.Reader, bits int) (p *big.Int, err os.Error) { + if bits < 1 { + err = os.EINVAL; + } + + bytes := make([]byte, (bits+7)/8); + p = new(big.Int); + p2 := new(big.Int); + + for { + _, err = io.ReadFull(rand, bytes); + if err != nil { + return; + } + + // Don't let the value be too small. + bytes[0] |= 0x80; + // Make the value odd since an even number this large certainly isn't prime. + bytes[len(bytes)-1] |= 1; + + p.SetBytes(bytes); + if p.ProbablyPrime(20) { + p2.Rsh(p, 1); // p2 = (p - 1)/2 + if p2.ProbablyPrime(20) { + return; + } + } + } + + return; +} + +// randomNumber returns a uniform random value in [0, max). +func randomNumber(rand io.Reader, max *big.Int) (n *big.Int, err os.Error) { + k := (max.Len() + 7)/8; + + // r is the number of bits in the used in the most significant byte of + // max. + r := uint(max.Len() % 8); + if r == 0 { + r = 8; + } + + bytes := make([]byte, k); + n = new(big.Int); + + for { + _, err = io.ReadFull(rand, bytes); + if err != nil { + return; + } + + // Clear bits in the first byte to increase the probability + // that the candidate is < max. + bytes[0] &= uint8(int(1<<r)-1); + + n.SetBytes(bytes); + if big.CmpInt(n, max) < 0 { + return; + } + } + + return; +} + +// A PublicKey represents the public part of an RSA key. +type PublicKey struct { + N *big.Int; // modulus + E int; // public exponent +} + +// A PrivateKey represents an RSA key +type PrivateKey struct { + PublicKey; // public part. + D *big.Int; // private exponent + P, Q *big.Int; // prime factors of N +} + +// GenerateKeyPair generates an RSA keypair of the given bit size. +func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { + priv = new(PrivateKey); + // Smaller public exponents lead to faster public key + // operations. Since the exponent must be coprime to + // (p-1)(q-1), the smallest possible value is 3. Some have + // suggested that a larger exponent (often 2**16+1) be used + // since previous implementation bugs[1] were avoided when this + // was the case. However, there are no current reasons not to use + // small exponents. + // [1] http://marc.info/?l=cryptography&m=115694833312008&w=2 + priv.E = 3; + + pminus1 := new(big.Int); + qminus1 := new(big.Int); + totient := new(big.Int); + + for { + p, err := randomSafePrime(rand, bits/2); + if err != nil { + return; + } + + q, err := randomSafePrime(rand, bits/2); + if err != nil { + return; + } + + if big.CmpInt(p, q) == 0 { + continue; + } + + n := new(big.Int).Mul(p, q); + pminus1.Sub(p, bigOne); + qminus1.Sub(q, bigOne); + totient.Mul(pminus1, qminus1); + + g := new(big.Int); + priv.D = new(big.Int); + y := new(big.Int); + e := big.NewInt(int64(priv.E)); + big.GcdInt(g, priv.D, y, e, totient); + + if big.CmpInt(g, bigOne) == 0 { + priv.D.Add(priv.D, totient); + priv.P = p; + priv.Q = q; + priv.N = n; + + break; + } + } + + return; +} + +// incCounter increments a four byte, big-endian counter. +func incCounter(c *[4]byte) { + if c[3]++; c[3] != 0 { + return; + } + if c[2]++; c[2] != 0 { + return; + } + if c[1]++; c[1] != 0 { + return; + } + c[0]++; +} + +// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function +// specified in PKCS#1 v2.1. +func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { + var counter [4]byte; + + done := 0; + for done < len(out) { + hash.Write(seed); + hash.Write(counter[0:4]); + digest := hash.Sum(); + hash.Reset(); + + for i := 0; i < len(digest) && done < len(out); i++ { + out[done] ^= digest[i]; + done++; + } + incCounter(&counter); + } +} + +// MessageTooLongError is returned when attempting to encrypt a message which +// is too large for the size of the public key. +type MessageTooLongError struct{} + +func (MessageTooLongError) String() string { + return "message too long for RSA public key size"; +} + +func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int { + e := big.NewInt(int64(pub.E)); + c.Exp(m, e, pub.N); + return c; +} + +// EncryptOAEP encrypts the given message with RSA-OAEP. +// The message must be no longer than the length of the public modulus less +// twice the hash length plus 2. +func EncryptOAEP(hash hash.Hash, rand io.Reader, pub *PublicKey, msg []byte, label []byte) (out []byte, err os.Error) { + hash.Reset(); + k := (pub.N.Len() + 7)/8; + if len(msg) > k - 2 * hash.Size() - 2 { + err = MessageTooLongError{}; + return; + } + + hash.Write(label); + lHash := hash.Sum(); + hash.Reset(); + + em := make([]byte, k); + seed := em[1 : 1 + hash.Size()]; + db := em[1 + hash.Size() : len(em)]; + + bytes.Copy(db[0 : hash.Size()], lHash); + db[len(db)-len(msg)-1] = 1; + bytes.Copy(db[len(db)-len(msg) : len(db)], msg); + + _, err = io.ReadFull(rand, seed); + if err != nil { + return; + } + + mgf1XOR(db, hash, seed); + mgf1XOR(seed, hash, db); + + m := new(big.Int); + m.SetBytes(em); + c := encrypt(new(big.Int), pub, m); + out = c.Bytes(); + return; +} + +// A DecryptionError represents a failure to decrypt a message. +// It is deliberately vague to avoid adaptive attacks. +type DecryptionError struct{} + +func (DecryptionError) String() string { + return "RSA decryption error"; +} + +// modInverse returns ia, the inverse of a in the multiplicative group of prime +// order n. It requires that a be a member of the group (i.e. less than n). +func modInverse(a, n *big.Int) (ia *big.Int) { + g := new(big.Int); + x := new(big.Int); + y := new(big.Int); + big.GcdInt(g, x, y, a, n); + if big.CmpInt(x, bigOne) < 0 { + // 0 is not the multiplicative inverse of any element so, if x + // < 1, then x is negative. + x.Add(x, n); + } + + return x; +} + +// constantTimeCompare returns 1 iff the two equal length slices, x +// and y, have equal contents. The time taken is a function of the length of +// the slices and is independent of the contents. +func constantTimeCompare(x, y []byte) int { + var v byte; + + for i := 0; i < len(x); i++ { + v |= x[i]^y[i]; + } + + return constantTimeByteEq(v, 0); +} + +// constantTimeSelect returns a if mask is 1 and b if mask is 0. +// Its behaviour is undefined if mask takes any other value. +func constantTimeSelect(mask, a, b int) int { + return ^(mask-1)&a | (mask-1)&b; +} + +// constantTimeByteEq returns 1 if a == b and 0 otherwise. +func constantTimeByteEq(a, b uint8) (mask int) { + x := ^(a^b); + x &= x>>4; + x &= x>>2; + x &= x>>1; + + return int(x); +} + +// decrypt performs an RSA decryption, resulting in a plaintext integer. If a +// random source is given, RSA blinding is used. +func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) { + // TODO(agl): can we get away with reusing blinds? + if big.CmpInt(c, priv.N) > 0 { + err = DecryptionError{}; + return; + } + + var ir *big.Int; + if rand != nil { + // Blinding enabled. Blinding involves multiplying c by r^e. + // Then the decryption operation performs (m^e * r^e)^d mod n + // which equals mr mod n. The factor of r can then be removed + // by multipling by the multiplicative inverse of r. + + r, err := randomNumber(rand, priv.N); + if err != nil { + return; + } + ir = modInverse(r, priv.N); + bigE := big.NewInt(int64(priv.E)); + rpowe := new(big.Int).Exp(r, bigE, priv.N); + c.Mul(c, rpowe); + c.Mod(c, priv.N); + } + + m = new(big.Int).Exp(c, priv.D, priv.N); + + if ir != nil { + // Unblind. + m.Mul(m, ir); + m.Mod(m, priv.N); + } + + return; +} + +// DecryptOAEP decrypts ciphertext using RSA-OAEP. +// If rand != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks. +func DecryptOAEP(hash hash.Hash, rand io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err os.Error) { + k := (priv.N.Len() + 7)/8; + if len(ciphertext) > k || + k < hash.Size() * 2 + 2 { + err = DecryptionError{}; + return; + } + + c := new(big.Int).SetBytes(ciphertext); + + m, err := decrypt(rand, priv, c); + if err != nil { + return; + } + + hash.Write(label); + lHash := hash.Sum(); + hash.Reset(); + + // Converting the plaintext number to bytes will strip any + // leading zeros so we may have to left pad. We do this unconditionally + // to avoid leaking timing information. (Although we still probably + // leak the number of leading zeros. It's not clear that we can do + // anything about this.) + em := leftPad(m.Bytes(), k); + + firstByteIsZero := constantTimeByteEq(em[0], 0); + + seed := em[1 : hash.Size() + 1]; + db := em[hash.Size() + 1 : len(em)]; + + mgf1XOR(seed, hash, db); + mgf1XOR(db, hash, seed); + + lHash2 := db[0 : hash.Size()]; + + // We have to validate the plaintext in contanst time in order to avoid + // attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal + // Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1 + // v2.0. In J. Kilian, editor, Advances in Cryptology. + lHash2Good := constantTimeCompare(lHash, lHash2); + + // The remainder of the plaintext must be zero or more 0x00, followed + // by 0x01, followed by the message. + // lookingForIndex: 1 iff we are still looking for the 0x01 + // index: the offset of the first 0x01 byte + // invalid: 1 iff we saw a non-zero byte before the 0x01. + var lookingForIndex, index, invalid int; + lookingForIndex = 1; + rest := db[hash.Size() : len(db)]; + + for i := 0; i < len(rest); i++ { + equals0 := constantTimeByteEq(rest[i], 0); + equals1 := constantTimeByteEq(rest[i], 1); + index = constantTimeSelect(lookingForIndex & equals1, i, index); + lookingForIndex = constantTimeSelect(equals1, 0, lookingForIndex); + invalid = constantTimeSelect(lookingForIndex & ^equals0, 1, invalid); + } + + if firstByteIsZero & lHash2Good & ^invalid & ^lookingForIndex != 1 { + err = DecryptionError{}; + return; + } + + msg = rest[index+1 : len(rest)]; + return; +} + +// leftPad returns a new slice of length size. The contents of input are right +// aligned in the new slice. +func leftPad(input []byte, size int) (out []byte) { + n := len(input); + if n > size { + n = size; + } + out = make([]byte, size); + bytes.Copy(out[len(out)-n : len(out)], input); + return; +} |