summaryrefslogtreecommitdiff
path: root/src/pkg/crypto/rsa/rsa.go
diff options
context:
space:
mode:
authorAdam Langley <agl@golang.org>2009-10-19 11:52:44 -0700
committerAdam Langley <agl@golang.org>2009-10-19 11:52:44 -0700
commita17a6f7bbd15f66146bff61e845bf1f707b81470 (patch)
treeb50dd67613da6ea6ca40005c6ce1c10ac915c482 /src/pkg/crypto/rsa/rsa.go
parent44a290f13fb2efa7fbc09ef19a1cf468224080ab (diff)
downloadgolang-a17a6f7bbd15f66146bff61e845bf1f707b81470.tar.gz
Add an RSA-OAEP implementation.
R=rsc APPROVED=rsc DELTA=734 (734 added, 0 deleted, 0 changed) OCL=35738 CL=35879
Diffstat (limited to 'src/pkg/crypto/rsa/rsa.go')
-rw-r--r--src/pkg/crypto/rsa/rsa.go413
1 files changed, 413 insertions, 0 deletions
diff --git a/src/pkg/crypto/rsa/rsa.go b/src/pkg/crypto/rsa/rsa.go
new file mode 100644
index 000000000..de98d3074
--- /dev/null
+++ b/src/pkg/crypto/rsa/rsa.go
@@ -0,0 +1,413 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package implements RSA encryption as specified in PKCS#1.
+package rsa
+
+// TODO(agl): Add support for PSS padding.
+
+import (
+ "bytes";
+ big "gmp";
+ "hash";
+ "io";
+ "os";
+)
+
+var bigOne = big.NewInt(1)
+
+// randomSafePrime returns a number, p, of the given size, such that p and
+// (p-1)/2 are both prime with high probability.
+func randomSafePrime(rand io.Reader, bits int) (p *big.Int, err os.Error) {
+ if bits < 1 {
+ err = os.EINVAL;
+ }
+
+ bytes := make([]byte, (bits+7)/8);
+ p = new(big.Int);
+ p2 := new(big.Int);
+
+ for {
+ _, err = io.ReadFull(rand, bytes);
+ if err != nil {
+ return;
+ }
+
+ // Don't let the value be too small.
+ bytes[0] |= 0x80;
+ // Make the value odd since an even number this large certainly isn't prime.
+ bytes[len(bytes)-1] |= 1;
+
+ p.SetBytes(bytes);
+ if p.ProbablyPrime(20) {
+ p2.Rsh(p, 1); // p2 = (p - 1)/2
+ if p2.ProbablyPrime(20) {
+ return;
+ }
+ }
+ }
+
+ return;
+}
+
+// randomNumber returns a uniform random value in [0, max).
+func randomNumber(rand io.Reader, max *big.Int) (n *big.Int, err os.Error) {
+ k := (max.Len() + 7)/8;
+
+ // r is the number of bits in the used in the most significant byte of
+ // max.
+ r := uint(max.Len() % 8);
+ if r == 0 {
+ r = 8;
+ }
+
+ bytes := make([]byte, k);
+ n = new(big.Int);
+
+ for {
+ _, err = io.ReadFull(rand, bytes);
+ if err != nil {
+ return;
+ }
+
+ // Clear bits in the first byte to increase the probability
+ // that the candidate is < max.
+ bytes[0] &= uint8(int(1<<r)-1);
+
+ n.SetBytes(bytes);
+ if big.CmpInt(n, max) < 0 {
+ return;
+ }
+ }
+
+ return;
+}
+
+// A PublicKey represents the public part of an RSA key.
+type PublicKey struct {
+ N *big.Int; // modulus
+ E int; // public exponent
+}
+
+// A PrivateKey represents an RSA key
+type PrivateKey struct {
+ PublicKey; // public part.
+ D *big.Int; // private exponent
+ P, Q *big.Int; // prime factors of N
+}
+
+// GenerateKeyPair generates an RSA keypair of the given bit size.
+func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) {
+ priv = new(PrivateKey);
+ // Smaller public exponents lead to faster public key
+ // operations. Since the exponent must be coprime to
+ // (p-1)(q-1), the smallest possible value is 3. Some have
+ // suggested that a larger exponent (often 2**16+1) be used
+ // since previous implementation bugs[1] were avoided when this
+ // was the case. However, there are no current reasons not to use
+ // small exponents.
+ // [1] http://marc.info/?l=cryptography&m=115694833312008&w=2
+ priv.E = 3;
+
+ pminus1 := new(big.Int);
+ qminus1 := new(big.Int);
+ totient := new(big.Int);
+
+ for {
+ p, err := randomSafePrime(rand, bits/2);
+ if err != nil {
+ return;
+ }
+
+ q, err := randomSafePrime(rand, bits/2);
+ if err != nil {
+ return;
+ }
+
+ if big.CmpInt(p, q) == 0 {
+ continue;
+ }
+
+ n := new(big.Int).Mul(p, q);
+ pminus1.Sub(p, bigOne);
+ qminus1.Sub(q, bigOne);
+ totient.Mul(pminus1, qminus1);
+
+ g := new(big.Int);
+ priv.D = new(big.Int);
+ y := new(big.Int);
+ e := big.NewInt(int64(priv.E));
+ big.GcdInt(g, priv.D, y, e, totient);
+
+ if big.CmpInt(g, bigOne) == 0 {
+ priv.D.Add(priv.D, totient);
+ priv.P = p;
+ priv.Q = q;
+ priv.N = n;
+
+ break;
+ }
+ }
+
+ return;
+}
+
+// incCounter increments a four byte, big-endian counter.
+func incCounter(c *[4]byte) {
+ if c[3]++; c[3] != 0 {
+ return;
+ }
+ if c[2]++; c[2] != 0 {
+ return;
+ }
+ if c[1]++; c[1] != 0 {
+ return;
+ }
+ c[0]++;
+}
+
+// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function
+// specified in PKCS#1 v2.1.
+func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
+ var counter [4]byte;
+
+ done := 0;
+ for done < len(out) {
+ hash.Write(seed);
+ hash.Write(counter[0:4]);
+ digest := hash.Sum();
+ hash.Reset();
+
+ for i := 0; i < len(digest) && done < len(out); i++ {
+ out[done] ^= digest[i];
+ done++;
+ }
+ incCounter(&counter);
+ }
+}
+
+// MessageTooLongError is returned when attempting to encrypt a message which
+// is too large for the size of the public key.
+type MessageTooLongError struct{}
+
+func (MessageTooLongError) String() string {
+ return "message too long for RSA public key size";
+}
+
+func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
+ e := big.NewInt(int64(pub.E));
+ c.Exp(m, e, pub.N);
+ return c;
+}
+
+// EncryptOAEP encrypts the given message with RSA-OAEP.
+// The message must be no longer than the length of the public modulus less
+// twice the hash length plus 2.
+func EncryptOAEP(hash hash.Hash, rand io.Reader, pub *PublicKey, msg []byte, label []byte) (out []byte, err os.Error) {
+ hash.Reset();
+ k := (pub.N.Len() + 7)/8;
+ if len(msg) > k - 2 * hash.Size() - 2 {
+ err = MessageTooLongError{};
+ return;
+ }
+
+ hash.Write(label);
+ lHash := hash.Sum();
+ hash.Reset();
+
+ em := make([]byte, k);
+ seed := em[1 : 1 + hash.Size()];
+ db := em[1 + hash.Size() : len(em)];
+
+ bytes.Copy(db[0 : hash.Size()], lHash);
+ db[len(db)-len(msg)-1] = 1;
+ bytes.Copy(db[len(db)-len(msg) : len(db)], msg);
+
+ _, err = io.ReadFull(rand, seed);
+ if err != nil {
+ return;
+ }
+
+ mgf1XOR(db, hash, seed);
+ mgf1XOR(seed, hash, db);
+
+ m := new(big.Int);
+ m.SetBytes(em);
+ c := encrypt(new(big.Int), pub, m);
+ out = c.Bytes();
+ return;
+}
+
+// A DecryptionError represents a failure to decrypt a message.
+// It is deliberately vague to avoid adaptive attacks.
+type DecryptionError struct{}
+
+func (DecryptionError) String() string {
+ return "RSA decryption error";
+}
+
+// modInverse returns ia, the inverse of a in the multiplicative group of prime
+// order n. It requires that a be a member of the group (i.e. less than n).
+func modInverse(a, n *big.Int) (ia *big.Int) {
+ g := new(big.Int);
+ x := new(big.Int);
+ y := new(big.Int);
+ big.GcdInt(g, x, y, a, n);
+ if big.CmpInt(x, bigOne) < 0 {
+ // 0 is not the multiplicative inverse of any element so, if x
+ // < 1, then x is negative.
+ x.Add(x, n);
+ }
+
+ return x;
+}
+
+// constantTimeCompare returns 1 iff the two equal length slices, x
+// and y, have equal contents. The time taken is a function of the length of
+// the slices and is independent of the contents.
+func constantTimeCompare(x, y []byte) int {
+ var v byte;
+
+ for i := 0; i < len(x); i++ {
+ v |= x[i]^y[i];
+ }
+
+ return constantTimeByteEq(v, 0);
+}
+
+// constantTimeSelect returns a if mask is 1 and b if mask is 0.
+// Its behaviour is undefined if mask takes any other value.
+func constantTimeSelect(mask, a, b int) int {
+ return ^(mask-1)&a | (mask-1)&b;
+}
+
+// constantTimeByteEq returns 1 if a == b and 0 otherwise.
+func constantTimeByteEq(a, b uint8) (mask int) {
+ x := ^(a^b);
+ x &= x>>4;
+ x &= x>>2;
+ x &= x>>1;
+
+ return int(x);
+}
+
+// decrypt performs an RSA decryption, resulting in a plaintext integer. If a
+// random source is given, RSA blinding is used.
+func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) {
+ // TODO(agl): can we get away with reusing blinds?
+ if big.CmpInt(c, priv.N) > 0 {
+ err = DecryptionError{};
+ return;
+ }
+
+ var ir *big.Int;
+ if rand != nil {
+ // Blinding enabled. Blinding involves multiplying c by r^e.
+ // Then the decryption operation performs (m^e * r^e)^d mod n
+ // which equals mr mod n. The factor of r can then be removed
+ // by multipling by the multiplicative inverse of r.
+
+ r, err := randomNumber(rand, priv.N);
+ if err != nil {
+ return;
+ }
+ ir = modInverse(r, priv.N);
+ bigE := big.NewInt(int64(priv.E));
+ rpowe := new(big.Int).Exp(r, bigE, priv.N);
+ c.Mul(c, rpowe);
+ c.Mod(c, priv.N);
+ }
+
+ m = new(big.Int).Exp(c, priv.D, priv.N);
+
+ if ir != nil {
+ // Unblind.
+ m.Mul(m, ir);
+ m.Mod(m, priv.N);
+ }
+
+ return;
+}
+
+// DecryptOAEP decrypts ciphertext using RSA-OAEP.
+// If rand != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks.
+func DecryptOAEP(hash hash.Hash, rand io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err os.Error) {
+ k := (priv.N.Len() + 7)/8;
+ if len(ciphertext) > k ||
+ k < hash.Size() * 2 + 2 {
+ err = DecryptionError{};
+ return;
+ }
+
+ c := new(big.Int).SetBytes(ciphertext);
+
+ m, err := decrypt(rand, priv, c);
+ if err != nil {
+ return;
+ }
+
+ hash.Write(label);
+ lHash := hash.Sum();
+ hash.Reset();
+
+ // Converting the plaintext number to bytes will strip any
+ // leading zeros so we may have to left pad. We do this unconditionally
+ // to avoid leaking timing information. (Although we still probably
+ // leak the number of leading zeros. It's not clear that we can do
+ // anything about this.)
+ em := leftPad(m.Bytes(), k);
+
+ firstByteIsZero := constantTimeByteEq(em[0], 0);
+
+ seed := em[1 : hash.Size() + 1];
+ db := em[hash.Size() + 1 : len(em)];
+
+ mgf1XOR(seed, hash, db);
+ mgf1XOR(db, hash, seed);
+
+ lHash2 := db[0 : hash.Size()];
+
+ // We have to validate the plaintext in contanst time in order to avoid
+ // attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal
+ // Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1
+ // v2.0. In J. Kilian, editor, Advances in Cryptology.
+ lHash2Good := constantTimeCompare(lHash, lHash2);
+
+ // The remainder of the plaintext must be zero or more 0x00, followed
+ // by 0x01, followed by the message.
+ // lookingForIndex: 1 iff we are still looking for the 0x01
+ // index: the offset of the first 0x01 byte
+ // invalid: 1 iff we saw a non-zero byte before the 0x01.
+ var lookingForIndex, index, invalid int;
+ lookingForIndex = 1;
+ rest := db[hash.Size() : len(db)];
+
+ for i := 0; i < len(rest); i++ {
+ equals0 := constantTimeByteEq(rest[i], 0);
+ equals1 := constantTimeByteEq(rest[i], 1);
+ index = constantTimeSelect(lookingForIndex & equals1, i, index);
+ lookingForIndex = constantTimeSelect(equals1, 0, lookingForIndex);
+ invalid = constantTimeSelect(lookingForIndex & ^equals0, 1, invalid);
+ }
+
+ if firstByteIsZero & lHash2Good & ^invalid & ^lookingForIndex != 1 {
+ err = DecryptionError{};
+ return;
+ }
+
+ msg = rest[index+1 : len(rest)];
+ return;
+}
+
+// leftPad returns a new slice of length size. The contents of input are right
+// aligned in the new slice.
+func leftPad(input []byte, size int) (out []byte) {
+ n := len(input);
+ if n > size {
+ n = size;
+ }
+ out = make([]byte, size);
+ bytes.Copy(out[len(out)-n : len(out)], input);
+ return;
+}