summaryrefslogtreecommitdiff
path: root/src/pkg/math/big/nat.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/math/big/nat.go')
-rw-r--r--src/pkg/math/big/nat.go309
1 files changed, 196 insertions, 113 deletions
diff --git a/src/pkg/math/big/nat.go b/src/pkg/math/big/nat.go
index eaa6ff066..9d09f97b7 100644
--- a/src/pkg/math/big/nat.go
+++ b/src/pkg/math/big/nat.go
@@ -236,7 +236,7 @@ func karatsubaSub(z, x nat, n int) {
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
-var karatsubaThreshold int = 32 // computed by calibrate.go
+var karatsubaThreshold int = 40 // computed by calibrate.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
@@ -342,7 +342,7 @@ func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
-// addAt implements z += x*(1<<(_W*i)); z must be long enough.
+// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
@@ -396,7 +396,7 @@ func (z nat) mul(x, y nat) nat {
}
// use basic multiplication if the numbers are small
- if n < karatsubaThreshold || n < 2 {
+ if n < karatsubaThreshold {
z = z.make(m + n)
basicMul(z, x, y)
return z.norm()
@@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat {
// determine Karatsuba length k such that
//
- // x = x1*b + x0
- // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
+ // x = xh*b + x0 (0 <= x0 < b)
+ // y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n)
@@ -417,27 +417,44 @@ func (z nat) mul(x, y nat) nat {
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
- z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
+ z = z[0 : m+n] // z has final length but may be incomplete
+ z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
- // If x1 and/or y1 are not 0, add missing terms to z explicitly:
+ // If xh != 0 or yh != 0, add the missing terms to z. For
+ //
+ // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
+ // yh = y1*b (0 <= y1 < b)
+ //
+ // the missing terms are
+ //
+ // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
//
- // m+n 2*k 0
- // z = [ ... | x0*y0 ]
- // + [ x1*y1 ]
- // + [ x1*y0 ]
- // + [ x0*y1 ]
+ // since all the yi for i > 1 are 0 by choice of k: If any of them
+ // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
+ // be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
- x1 := x[k:] // x1 is normalized because x is
- y1 := y[k:] // y1 is normalized because y is
var t nat
- t = t.mul(x1, y1)
- copy(z[2*k:], t)
- z[2*k+len(t):].clear() // upper portion of z is garbage
- t = t.mul(x1, y0.norm())
- addAt(z, t, k)
- t = t.mul(x0.norm(), y1)
+
+ // add x0*y1*b
+ x0 := x0.norm()
+ y1 := y[k:] // y1 is normalized because y is
+ t = t.mul(x0, y1) // update t so we don't lose t's underlying array
addAt(z, t, k)
+
+ // add xi*y0<<i, xi*y1*b<<(i+k)
+ y0 := y0.norm()
+ for i := k; i < len(x); i += k {
+ xi := x[i:]
+ if len(xi) > k {
+ xi = xi[:k]
+ }
+ xi = xi.norm()
+ t = t.mul(xi, y0)
+ addAt(z, t, i)
+ t = t.mul(xi, y1)
+ addAt(z, t, i+k)
+ }
}
return z.norm()
@@ -493,14 +510,9 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
}
if len(v) == 1 {
- var rprime Word
- q, rprime = z.divW(u, v[0])
- if rprime > 0 {
- r = z2.make(1)
- r[0] = rprime
- } else {
- r = z2.make(0)
- }
+ var r2 Word
+ q, r2 = z.divW(u, v[0])
+ r = z2.setWord(r2)
return
}
@@ -740,7 +752,7 @@ func (x nat) string(charset string) string {
// convert power of two and non power of two bases separately
if b == b&-b {
// shift is base-b digit size in bits
- shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2
+ shift := trailingZeroBits(b) // shift > 0 because b >= 2
mask := Word(1)<<shift - 1
w := x[0]
nbits := uint(_W) // number of unprocessed bits in w
@@ -814,18 +826,18 @@ func (x nat) string(charset string) string {
// Convert words of q to base b digits in s. If q is large, it is recursively "split in half"
// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using
-// repeated nat/Word divison.
+// repeated nat/Word division.
//
-// The iterative method processes n Words by n divW() calls, each of which visits every Word in the
-// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s.
-// Recursive conversion divides q by its approximate square root, yielding two parts, each half
+// The iterative method processes n Words by n divW() calls, each of which visits every Word in the
+// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s.
+// Recursive conversion divides q by its approximate square root, yielding two parts, each half
// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s
// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and
-// is made better by splitting the subblocks recursively. Best is to split blocks until one more
-// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the
-// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the
-// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
-// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
+// is made better by splitting the subblocks recursively. Best is to split blocks until one more
+// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the
+// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the
+// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
+// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
// specific hardware.
//
func (q nat) convertWords(s []byte, charset string, b Word, ndigits int, bb Word, table []divisor) {
@@ -908,8 +920,10 @@ type divisor struct {
ndigits int // digit length of divisor in terms of output base digits
}
-var cacheBase10 [64]divisor // cached divisors for base 10
-var cacheLock sync.Mutex // protects cacheBase10
+var cacheBase10 struct {
+ sync.Mutex
+ table [64]divisor // cached divisors for base 10
+}
// expWW computes x**y
func (z nat) expWW(x, y Word) nat {
@@ -925,34 +939,28 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
// determine k where (bb**leafSize)**(2**k) >= sqrt(x)
k := 1
- for words := leafSize; words < m>>1 && k < len(cacheBase10); words <<= 1 {
+ for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 {
k++
}
- // create new table of divisors or extend and reuse existing table as appropriate
- var table []divisor
- var cached bool
- switch b {
- case 10:
- table = cacheBase10[0:k] // reuse old table for this conversion
- cached = true
- default:
- table = make([]divisor, k) // new table for this conversion
+ // reuse and extend existing table of divisors or create new table as appropriate
+ var table []divisor // for b == 10, table overlaps with cacheBase10.table
+ if b == 10 {
+ cacheBase10.Lock()
+ table = cacheBase10.table[0:k] // reuse old table for this conversion
+ } else {
+ table = make([]divisor, k) // create new table for this conversion
}
// extend table
if table[k-1].ndigits == 0 {
- if cached {
- cacheLock.Lock() // begin critical section
- }
-
// add new entries as needed
var larger nat
for i := 0; i < k; i++ {
if table[i].ndigits == 0 {
if i == 0 {
- table[i].bbb = nat(nil).expWW(bb, Word(leafSize))
- table[i].ndigits = ndigits * leafSize
+ table[0].bbb = nat(nil).expWW(bb, Word(leafSize))
+ table[0].ndigits = ndigits * leafSize
} else {
table[i].bbb = nat(nil).mul(table[i-1].bbb, table[i-1].bbb)
table[i].ndigits = 2 * table[i-1].ndigits
@@ -968,10 +976,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
table[i].nbits = table[i].bbb.bitLen()
}
}
+ }
- if cached {
- cacheLock.Unlock() // end critical section
- }
+ if b == 10 {
+ cacheBase10.Unlock()
}
return table
@@ -993,10 +1001,9 @@ var deBruijn64Lookup = []byte{
54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
-// trailingZeroBits returns the number of consecutive zero bits on the right
-// side of the given Word.
-// See Knuth, volume 4, section 7.3.1
-func trailingZeroBits(x Word) int {
+// trailingZeroBits returns the number of consecutive least significant zero
+// bits of x.
+func trailingZeroBits(x Word) uint {
// x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multiplying by a power of two is equivalent to
@@ -1005,18 +1012,33 @@ func trailingZeroBits(x Word) int {
// Therefore, if we have a left shifted version of this constant we can
// find by how many bits it was shifted by looking at which six bit
// substring ended up at the top of the word.
+ // (Knuth, volume 4, section 7.3.1)
switch _W {
case 32:
- return int(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
+ return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
case 64:
- return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
+ return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
default:
- panic("Unknown word size")
+ panic("unknown word size")
}
return 0
}
+// trailingZeroBits returns the number of consecutive least significant zero
+// bits of x.
+func (x nat) trailingZeroBits() uint {
+ if len(x) == 0 {
+ return 0
+ }
+ var i uint
+ for x[i] == 0 {
+ i++
+ }
+ // x[i] != 0
+ return i*_W + trailingZeroBits(x[i])
+}
+
// z = x << s
func (z nat) shl(x nat, s uint) nat {
m := len(x)
@@ -1169,29 +1191,6 @@ func (x nat) modW(d Word) (r Word) {
return divWVW(q, 0, x, d)
}
-// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0.
-func (x nat) powersOfTwoDecompose() (q nat, k int) {
- if len(x) == 0 {
- return x, 0
- }
-
- // One of the words must be non-zero by definition,
- // so this loop will terminate with i < len(x), and
- // i is the number of 0 words.
- i := 0
- for x[i] == 0 {
- i++
- }
- n := trailingZeroBits(x[i]) // x[i] != 0
-
- q = make(nat, len(x)-i)
- shrVU(q, x[i:], uint(n))
-
- q = q.norm()
- k = i*_W + n
- return
-}
-
// random creates a random integer in [0..limit), using the space in z if
// possible. n is the bit length of limit.
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
@@ -1207,17 +1206,19 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
mask := Word((1 << bitLengthOfMSW) - 1)
for {
- for i := range z {
- switch _W {
- case 32:
+ switch _W {
+ case 32:
+ for i := range z {
z[i] = Word(rand.Uint32())
- case 64:
+ }
+ case 64:
+ for i := range z {
z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
}
+ default:
+ panic("unknown word size")
}
-
z[len(limit)-1] &= mask
-
if z.cmp(limit) < 0 {
break
}
@@ -1226,11 +1227,11 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
return z.norm()
}
-// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
-// reuses the storage of z if possible.
+// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
+// otherwise it sets z to x**y. The result is the value of z.
func (z nat) expNN(x, y, m nat) nat {
if alias(z, x) || alias(z, y) {
- // We cannot allow in place modification of x or y.
+ // We cannot allow in-place modification of x or y.
z = nil
}
@@ -1239,15 +1240,24 @@ func (z nat) expNN(x, y, m nat) nat {
z[0] = 1
return z
}
+ // y > 0
- if m != nil {
+ if len(m) != 0 {
// We likely end up being as long as the modulus.
z = z.make(len(m))
}
z = z.set(x)
- v := y[len(y)-1]
- // It's invalid for the most significant word to be zero, therefore we
- // will find a one bit.
+
+ // If the base is non-trivial and the exponent is large, we use
+ // 4-bit, windowed exponentiation. This involves precomputing 14 values
+ // (x^2...x^15) but then reduces the number of multiply-reduces by a
+ // third. Even for a 32-bit exponent, this reduces the number of
+ // operations.
+ if len(x) > 1 && len(y) > 1 && len(m) > 0 {
+ return z.expNNWindowed(x, y, m)
+ }
+
+ v := y[len(y)-1] // v > 0 because y is normalized and y > 0
shift := leadingZeros(v) + 1
v <<= shift
var q nat
@@ -1259,15 +1269,21 @@ func (z nat) expNN(x, y, m nat) nat {
// we also multiply by x, thus adding one to the power.
w := _W - int(shift)
+ // zz and r are used to avoid allocating in mul and div as
+ // otherwise the arguments would alias.
+ var zz, r nat
for j := 0; j < w; j++ {
- z = z.mul(z, z)
+ zz = zz.mul(z, z)
+ zz, z = z, zz
if v&mask != 0 {
- z = z.mul(z, x)
+ zz = zz.mul(z, x)
+ zz, z = z, zz
}
- if m != nil {
- q, z = q.div(z, z, m)
+ if len(m) != 0 {
+ zz, r = zz.div(r, z, m)
+ zz, r, q, z = q, z, zz, r
}
v <<= 1
@@ -1277,14 +1293,17 @@ func (z nat) expNN(x, y, m nat) nat {
v = y[i]
for j := 0; j < _W; j++ {
- z = z.mul(z, z)
+ zz = zz.mul(z, z)
+ zz, z = z, zz
if v&mask != 0 {
- z = z.mul(z, x)
+ zz = zz.mul(z, x)
+ zz, z = z, zz
}
- if m != nil {
- q, z = q.div(z, z, m)
+ if len(m) != 0 {
+ zz, r = zz.div(r, z, m)
+ zz, r, q, z = q, z, zz, r
}
v <<= 1
@@ -1294,6 +1313,69 @@ func (z nat) expNN(x, y, m nat) nat {
return z.norm()
}
+// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
+func (z nat) expNNWindowed(x, y, m nat) nat {
+ // zz and r are used to avoid allocating in mul and div as otherwise
+ // the arguments would alias.
+ var zz, r nat
+
+ const n = 4
+ // powers[i] contains x^i.
+ var powers [1 << n]nat
+ powers[0] = natOne
+ powers[1] = x
+ for i := 2; i < 1<<n; i += 2 {
+ p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+ *p = p.mul(*p2, *p2)
+ zz, r = zz.div(r, *p, m)
+ *p, r = r, *p
+ *p1 = p1.mul(*p, x)
+ zz, r = zz.div(r, *p1, m)
+ *p1, r = r, *p1
+ }
+
+ z = z.setWord(1)
+
+ for i := len(y) - 1; i >= 0; i-- {
+ yi := y[i]
+ for j := 0; j < _W; j += n {
+ if i != len(y)-1 || j != 0 {
+ // Unrolled loop for significant performance
+ // gain. Use go test -bench=".*" in crypto/rsa
+ // to check performance before making changes.
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+ }
+
+ zz = zz.mul(z, powers[yi>>(_W-n)])
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ yi <<= n
+ }
+ }
+
+ return z.norm()
+}
+
// probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
// If it returns true, n is prime with probability 1 - 1/4^reps.
// If it returns false, n is not prime.
@@ -1343,8 +1425,9 @@ func (n nat) probablyPrime(reps int) bool {
}
nm1 := nat(nil).sub(n, natOne)
- // 1<<k * q = nm1;
- q, k := nm1.powersOfTwoDecompose()
+ // determine q, k such that nm1 = q << k
+ k := nm1.trailingZeroBits()
+ q := nat(nil).shr(nm1, k)
nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0])))
@@ -1360,7 +1443,7 @@ NextRandom:
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
- for j := 1; j < k; j++ {
+ for j := uint(1); j < k; j++ {
y = y.mul(y, y)
quotient, y = quotient.div(y, y, n)
if y.cmp(nm1) == 0 {