diff options
Diffstat (limited to 'src/pkg/math/big/nat.go')
-rw-r--r-- | src/pkg/math/big/nat.go | 309 |
1 files changed, 196 insertions, 113 deletions
diff --git a/src/pkg/math/big/nat.go b/src/pkg/math/big/nat.go index eaa6ff066..9d09f97b7 100644 --- a/src/pkg/math/big/nat.go +++ b/src/pkg/math/big/nat.go @@ -236,7 +236,7 @@ func karatsubaSub(z, x nat, n int) { // Operands that are shorter than karatsubaThreshold are multiplied using // "grade school" multiplication; for longer operands the Karatsuba algorithm // is used. -var karatsubaThreshold int = 32 // computed by calibrate.go +var karatsubaThreshold int = 40 // computed by calibrate.go // karatsuba multiplies x and y and leaves the result in z. // Both x and y must have the same length n and n must be a @@ -342,7 +342,7 @@ func alias(x, y nat) bool { return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] } -// addAt implements z += x*(1<<(_W*i)); z must be long enough. +// addAt implements z += x<<(_W*i); z must be long enough. // (we don't use nat.add because we need z to stay the same // slice, and we don't need to normalize z after each addition) func addAt(z, x nat, i int) { @@ -396,7 +396,7 @@ func (z nat) mul(x, y nat) nat { } // use basic multiplication if the numbers are small - if n < karatsubaThreshold || n < 2 { + if n < karatsubaThreshold { z = z.make(m + n) basicMul(z, x, y) return z.norm() @@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat { // determine Karatsuba length k such that // - // x = x1*b + x0 - // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) // b = 1<<(_W*k) ("base" of digits xi, yi) // k := karatsubaLen(n) @@ -417,27 +417,44 @@ func (z nat) mul(x, y nat) nat { y0 := y[0:k] // y0 is not normalized z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y karatsuba(z, x0, y0) - z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) - // If x1 and/or y1 are not 0, add missing terms to z explicitly: + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 // - // m+n 2*k 0 - // z = [ ... | x0*y0 ] - // + [ x1*y1 ] - // + [ x1*y0 ] - // + [ x0*y1 ] + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. // if k < n || m != n { - x1 := x[k:] // x1 is normalized because x is - y1 := y[k:] // y1 is normalized because y is var t nat - t = t.mul(x1, y1) - copy(z[2*k:], t) - z[2*k+len(t):].clear() // upper portion of z is garbage - t = t.mul(x1, y0.norm()) - addAt(z, t, k) - t = t.mul(x0.norm(), y1) + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + t = t.mul(x0, y1) // update t so we don't lose t's underlying array addAt(z, t, k) + + // add xi*y0<<i, xi*y1*b<<(i+k) + y0 := y0.norm() + for i := k; i < len(x); i += k { + xi := x[i:] + if len(xi) > k { + xi = xi[:k] + } + xi = xi.norm() + t = t.mul(xi, y0) + addAt(z, t, i) + t = t.mul(xi, y1) + addAt(z, t, i+k) + } } return z.norm() @@ -493,14 +510,9 @@ func (z nat) div(z2, u, v nat) (q, r nat) { } if len(v) == 1 { - var rprime Word - q, rprime = z.divW(u, v[0]) - if rprime > 0 { - r = z2.make(1) - r[0] = rprime - } else { - r = z2.make(0) - } + var r2 Word + q, r2 = z.divW(u, v[0]) + r = z2.setWord(r2) return } @@ -740,7 +752,7 @@ func (x nat) string(charset string) string { // convert power of two and non power of two bases separately if b == b&-b { // shift is base-b digit size in bits - shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2 + shift := trailingZeroBits(b) // shift > 0 because b >= 2 mask := Word(1)<<shift - 1 w := x[0] nbits := uint(_W) // number of unprocessed bits in w @@ -814,18 +826,18 @@ func (x nat) string(charset string) string { // Convert words of q to base b digits in s. If q is large, it is recursively "split in half" // by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using -// repeated nat/Word divison. +// repeated nat/Word division. // -// The iterative method processes n Words by n divW() calls, each of which visits every Word in the -// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. -// Recursive conversion divides q by its approximate square root, yielding two parts, each half +// The iterative method processes n Words by n divW() calls, each of which visits every Word in the +// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. +// Recursive conversion divides q by its approximate square root, yielding two parts, each half // the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s // plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and -// is made better by splitting the subblocks recursively. Best is to split blocks until one more -// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the -// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the -// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and -// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for +// is made better by splitting the subblocks recursively. Best is to split blocks until one more +// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the +// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the +// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and +// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for // specific hardware. // func (q nat) convertWords(s []byte, charset string, b Word, ndigits int, bb Word, table []divisor) { @@ -908,8 +920,10 @@ type divisor struct { ndigits int // digit length of divisor in terms of output base digits } -var cacheBase10 [64]divisor // cached divisors for base 10 -var cacheLock sync.Mutex // protects cacheBase10 +var cacheBase10 struct { + sync.Mutex + table [64]divisor // cached divisors for base 10 +} // expWW computes x**y func (z nat) expWW(x, y Word) nat { @@ -925,34 +939,28 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor { // determine k where (bb**leafSize)**(2**k) >= sqrt(x) k := 1 - for words := leafSize; words < m>>1 && k < len(cacheBase10); words <<= 1 { + for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 { k++ } - // create new table of divisors or extend and reuse existing table as appropriate - var table []divisor - var cached bool - switch b { - case 10: - table = cacheBase10[0:k] // reuse old table for this conversion - cached = true - default: - table = make([]divisor, k) // new table for this conversion + // reuse and extend existing table of divisors or create new table as appropriate + var table []divisor // for b == 10, table overlaps with cacheBase10.table + if b == 10 { + cacheBase10.Lock() + table = cacheBase10.table[0:k] // reuse old table for this conversion + } else { + table = make([]divisor, k) // create new table for this conversion } // extend table if table[k-1].ndigits == 0 { - if cached { - cacheLock.Lock() // begin critical section - } - // add new entries as needed var larger nat for i := 0; i < k; i++ { if table[i].ndigits == 0 { if i == 0 { - table[i].bbb = nat(nil).expWW(bb, Word(leafSize)) - table[i].ndigits = ndigits * leafSize + table[0].bbb = nat(nil).expWW(bb, Word(leafSize)) + table[0].ndigits = ndigits * leafSize } else { table[i].bbb = nat(nil).mul(table[i-1].bbb, table[i-1].bbb) table[i].ndigits = 2 * table[i-1].ndigits @@ -968,10 +976,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor { table[i].nbits = table[i].bbb.bitLen() } } + } - if cached { - cacheLock.Unlock() // end critical section - } + if b == 10 { + cacheBase10.Unlock() } return table @@ -993,10 +1001,9 @@ var deBruijn64Lookup = []byte{ 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, } -// trailingZeroBits returns the number of consecutive zero bits on the right -// side of the given Word. -// See Knuth, volume 4, section 7.3.1 -func trailingZeroBits(x Word) int { +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func trailingZeroBits(x Word) uint { // x & -x leaves only the right-most bit set in the word. Let k be the // index of that bit. Since only a single bit is set, the value is two // to the power of k. Multiplying by a power of two is equivalent to @@ -1005,18 +1012,33 @@ func trailingZeroBits(x Word) int { // Therefore, if we have a left shifted version of this constant we can // find by how many bits it was shifted by looking at which six bit // substring ended up at the top of the word. + // (Knuth, volume 4, section 7.3.1) switch _W { case 32: - return int(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) + return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) case 64: - return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) + return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) default: - panic("Unknown word size") + panic("unknown word size") } return 0 } +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func (x nat) trailingZeroBits() uint { + if len(x) == 0 { + return 0 + } + var i uint + for x[i] == 0 { + i++ + } + // x[i] != 0 + return i*_W + trailingZeroBits(x[i]) +} + // z = x << s func (z nat) shl(x nat, s uint) nat { m := len(x) @@ -1169,29 +1191,6 @@ func (x nat) modW(d Word) (r Word) { return divWVW(q, 0, x, d) } -// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0. -func (x nat) powersOfTwoDecompose() (q nat, k int) { - if len(x) == 0 { - return x, 0 - } - - // One of the words must be non-zero by definition, - // so this loop will terminate with i < len(x), and - // i is the number of 0 words. - i := 0 - for x[i] == 0 { - i++ - } - n := trailingZeroBits(x[i]) // x[i] != 0 - - q = make(nat, len(x)-i) - shrVU(q, x[i:], uint(n)) - - q = q.norm() - k = i*_W + n - return -} - // random creates a random integer in [0..limit), using the space in z if // possible. n is the bit length of limit. func (z nat) random(rand *rand.Rand, limit nat, n int) nat { @@ -1207,17 +1206,19 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { mask := Word((1 << bitLengthOfMSW) - 1) for { - for i := range z { - switch _W { - case 32: + switch _W { + case 32: + for i := range z { z[i] = Word(rand.Uint32()) - case 64: + } + case 64: + for i := range z { z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 } + default: + panic("unknown word size") } - z[len(limit)-1] &= mask - if z.cmp(limit) < 0 { break } @@ -1226,11 +1227,11 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { return z.norm() } -// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It -// reuses the storage of z if possible. +// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; +// otherwise it sets z to x**y. The result is the value of z. func (z nat) expNN(x, y, m nat) nat { if alias(z, x) || alias(z, y) { - // We cannot allow in place modification of x or y. + // We cannot allow in-place modification of x or y. z = nil } @@ -1239,15 +1240,24 @@ func (z nat) expNN(x, y, m nat) nat { z[0] = 1 return z } + // y > 0 - if m != nil { + if len(m) != 0 { // We likely end up being as long as the modulus. z = z.make(len(m)) } z = z.set(x) - v := y[len(y)-1] - // It's invalid for the most significant word to be zero, therefore we - // will find a one bit. + + // If the base is non-trivial and the exponent is large, we use + // 4-bit, windowed exponentiation. This involves precomputing 14 values + // (x^2...x^15) but then reduces the number of multiply-reduces by a + // third. Even for a 32-bit exponent, this reduces the number of + // operations. + if len(x) > 1 && len(y) > 1 && len(m) > 0 { + return z.expNNWindowed(x, y, m) + } + + v := y[len(y)-1] // v > 0 because y is normalized and y > 0 shift := leadingZeros(v) + 1 v <<= shift var q nat @@ -1259,15 +1269,21 @@ func (z nat) expNN(x, y, m nat) nat { // we also multiply by x, thus adding one to the power. w := _W - int(shift) + // zz and r are used to avoid allocating in mul and div as + // otherwise the arguments would alias. + var zz, r nat for j := 0; j < w; j++ { - z = z.mul(z, z) + zz = zz.mul(z, z) + zz, z = z, zz if v&mask != 0 { - z = z.mul(z, x) + zz = zz.mul(z, x) + zz, z = z, zz } - if m != nil { - q, z = q.div(z, z, m) + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r } v <<= 1 @@ -1277,14 +1293,17 @@ func (z nat) expNN(x, y, m nat) nat { v = y[i] for j := 0; j < _W; j++ { - z = z.mul(z, z) + zz = zz.mul(z, z) + zz, z = z, zz if v&mask != 0 { - z = z.mul(z, x) + zz = zz.mul(z, x) + zz, z = z, zz } - if m != nil { - q, z = q.div(z, z, m) + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r } v <<= 1 @@ -1294,6 +1313,69 @@ func (z nat) expNN(x, y, m nat) nat { return z.norm() } +// expNNWindowed calculates x**y mod m using a fixed, 4-bit window. +func (z nat) expNNWindowed(x, y, m nat) nat { + // zz and r are used to avoid allocating in mul and div as otherwise + // the arguments would alias. + var zz, r nat + + const n = 4 + // powers[i] contains x^i. + var powers [1 << n]nat + powers[0] = natOne + powers[1] = x + for i := 2; i < 1<<n; i += 2 { + p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1] + *p = p.mul(*p2, *p2) + zz, r = zz.div(r, *p, m) + *p, r = r, *p + *p1 = p1.mul(*p, x) + zz, r = zz.div(r, *p1, m) + *p1, r = r, *p1 + } + + z = z.setWord(1) + + for i := len(y) - 1; i >= 0; i-- { + yi := y[i] + for j := 0; j < _W; j += n { + if i != len(y)-1 || j != 0 { + // Unrolled loop for significant performance + // gain. Use go test -bench=".*" in crypto/rsa + // to check performance before making changes. + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + } + + zz = zz.mul(z, powers[yi>>(_W-n)]) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + yi <<= n + } + } + + return z.norm() +} + // probablyPrime performs reps Miller-Rabin tests to check whether n is prime. // If it returns true, n is prime with probability 1 - 1/4^reps. // If it returns false, n is not prime. @@ -1343,8 +1425,9 @@ func (n nat) probablyPrime(reps int) bool { } nm1 := nat(nil).sub(n, natOne) - // 1<<k * q = nm1; - q, k := nm1.powersOfTwoDecompose() + // determine q, k such that nm1 = q << k + k := nm1.trailingZeroBits() + q := nat(nil).shr(nm1, k) nm3 := nat(nil).sub(nm1, natTwo) rand := rand.New(rand.NewSource(int64(n[0]))) @@ -1360,7 +1443,7 @@ NextRandom: if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { continue } - for j := 1; j < k; j++ { + for j := uint(1); j < k; j++ { y = y.mul(y, y) quotient, y = quotient.div(y, y, n) if y.cmp(nm1) == 0 { |