diff options
Diffstat (limited to 'src/pkg/big')
-rwxr-xr-x | src/pkg/big/int.go | 94 | ||||
-rwxr-xr-x | src/pkg/big/int_test.go | 197 | ||||
-rwxr-xr-x | src/pkg/big/nat.go | 75 | ||||
-rwxr-xr-x | src/pkg/big/nat_test.go | 32 | ||||
-rw-r--r-- | src/pkg/big/rat.go | 8 |
5 files changed, 361 insertions, 45 deletions
diff --git a/src/pkg/big/int.go b/src/pkg/big/int.go index f1ea7b1c2..74fbef48d 100755 --- a/src/pkg/big/int.go +++ b/src/pkg/big/int.go @@ -309,42 +309,68 @@ func (x *Int) Cmp(y *Int) (r int) { func (x *Int) String() string { - s := "" - if x.neg { - s = "-" + switch { + case x == nil: + return "<nil>" + case x.neg: + return "-" + x.abs.decimalString() } - return s + x.abs.string(10) + return x.abs.decimalString() } -func fmtbase(ch int) int { +func charset(ch int) string { switch ch { case 'b': - return 2 + return lowercaseDigits[0:2] case 'o': - return 8 - case 'd': - return 10 + return lowercaseDigits[0:8] + case 'd', 'v': + return lowercaseDigits[0:10] case 'x': - return 16 + return lowercaseDigits[0:16] + case 'X': + return uppercaseDigits[0:16] } - return 10 + return "" // unknown format } // Format is a support routine for fmt.Formatter. It accepts -// the formats 'b' (binary), 'o' (octal), 'd' (decimal) and -// 'x' (hexadecimal). +// the formats 'b' (binary), 'o' (octal), 'd' (decimal), 'x' +// (lowercase hexadecimal), and 'X' (uppercase hexadecimal). // func (x *Int) Format(s fmt.State, ch int) { - if x == nil { + cs := charset(ch) + + // special cases + switch { + case cs == "": + // unknown format + fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String()) + return + case x == nil: fmt.Fprint(s, "<nil>") return } + + // determine format + format := "%s" + if s.Flag('#') { + switch ch { + case 'o': + format = "0%s" + case 'x': + format = "0x%s" + case 'X': + format = "0X%s" + } + } if x.neg { - fmt.Fprint(s, "-") + format = "-" + format } - fmt.Fprint(s, x.abs.string(fmtbase(ch))) + + fmt.Fprintf(s, format, x.abs.string(cs)) } @@ -560,6 +586,42 @@ func (z *Int) Rsh(x *Int, n uint) *Int { } +// Bit returns the value of the i'th bit of z. That is, it +// returns (z>>i)&1. The bit index i must be >= 0. +func (z *Int) Bit(i int) uint { + if i < 0 { + panic("negative bit index") + } + if z.neg { + t := nat{}.sub(z.abs, natOne) + return t.bit(uint(i)) ^ 1 + } + + return z.abs.bit(uint(i)) +} + + +// SetBit sets the i'th bit of z to bit and returns z. +// That is, if bit is 1 SetBit sets z = x | (1 << i); +// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1, +// SetBit will panic. +func (z *Int) SetBit(x *Int, i int, b uint) *Int { + if i < 0 { + panic("negative bit index") + } + if x.neg { + t := z.abs.sub(x.abs, natOne) + t = t.setBit(t, uint(i), b^1) + z.abs = t.add(t, natOne) + z.neg = len(z.abs) > 0 + return z + } + z.abs = z.abs.setBit(x.abs, uint(i), b) + z.neg = false + return z +} + + // And sets z = x & y and returns z. func (z *Int) And(x, y *Int) *Int { if x.neg == y.neg { diff --git a/src/pkg/big/int_test.go b/src/pkg/big/int_test.go index 9c19dd5da..595f04956 100755 --- a/src/pkg/big/int_test.go +++ b/src/pkg/big/int_test.go @@ -348,6 +348,55 @@ func TestSetString(t *testing.T) { } +var formatTests = []struct { + input string + format string + output string +}{ + {"<nil>", "%x", "<nil>"}, + {"<nil>", "%#x", "<nil>"}, + {"<nil>", "%#y", "%!y(big.Int=<nil>)"}, + + {"10", "%b", "1010"}, + {"10", "%o", "12"}, + {"10", "%d", "10"}, + {"10", "%v", "10"}, + {"10", "%x", "a"}, + {"10", "%X", "A"}, + {"-10", "%X", "-A"}, + {"10", "%y", "%!y(big.Int=10)"}, + {"-10", "%y", "%!y(big.Int=-10)"}, + + {"10", "%#b", "1010"}, + {"10", "%#o", "012"}, + {"10", "%#d", "10"}, + {"10", "%#v", "10"}, + {"10", "%#x", "0xa"}, + {"10", "%#X", "0XA"}, + {"-10", "%#X", "-0XA"}, + {"10", "%#y", "%!y(big.Int=10)"}, + {"-10", "%#y", "%!y(big.Int=-10)"}, +} + + +func TestFormat(t *testing.T) { + for i, test := range formatTests { + var x *Int + if test.input != "<nil>" { + var ok bool + x, ok = new(Int).SetString(test.input, 0) + if !ok { + t.Errorf("#%d failed reading input %s", i, test.input) + } + } + output := fmt.Sprintf(test.format, x) + if output != test.output { + t.Errorf("#%d got %s; want %s", i, output, test.output) + } + } +} + + // Examples from the Go Language Spec, section "Arithmetic operators" var divisionSignsTests = []struct { x, y int64 @@ -985,6 +1034,152 @@ func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) { } +func altBit(x *Int, i int) uint { + z := new(Int).Rsh(x, uint(i)) + z = z.And(z, NewInt(1)) + if z.Cmp(new(Int)) != 0 { + return 1 + } + return 0 +} + + +func altSetBit(z *Int, x *Int, i int, b uint) *Int { + one := NewInt(1) + m := one.Lsh(one, uint(i)) + switch b { + case 1: + return z.Or(x, m) + case 0: + return z.AndNot(x, m) + } + panic("set bit is not 0 or 1") +} + + +func testBitset(t *testing.T, x *Int) { + n := x.BitLen() + z := new(Int).Set(x) + z1 := new(Int).Set(x) + for i := 0; i < n+10; i++ { + old := z.Bit(i) + old1 := altBit(z1, i) + if old != old1 { + t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1) + } + z := new(Int).SetBit(z, i, 1) + z1 := altSetBit(new(Int), z1, i, 1) + if z.Bit(i) == 0 { + t.Errorf("bitset: bit %d of %s got 0 want 1", i, x) + } + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1) + } + z.SetBit(z, i, 0) + altSetBit(z1, z1, i, 0) + if z.Bit(i) != 0 { + t.Errorf("bitset: bit %d of %s got 1 want 0", i, x) + } + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1) + } + altSetBit(z1, z1, i, old) + z.SetBit(z, i, old) + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1) + } + } + if z.Cmp(x) != 0 { + t.Errorf("bitset: got %s want %s", z, x) + } +} + + +var bitsetTests = []struct { + x string + i int + b uint +}{ + {"0", 0, 0}, + {"0", 200, 0}, + {"1", 0, 1}, + {"1", 1, 0}, + {"-1", 0, 1}, + {"-1", 200, 1}, + {"0x2000000000000000000000000000", 108, 0}, + {"0x2000000000000000000000000000", 109, 1}, + {"0x2000000000000000000000000000", 110, 0}, + {"-0x2000000000000000000000000001", 108, 1}, + {"-0x2000000000000000000000000001", 109, 0}, + {"-0x2000000000000000000000000001", 110, 1}, +} + + +func TestBitSet(t *testing.T) { + for _, test := range bitwiseTests { + x := new(Int) + x.SetString(test.x, 0) + testBitset(t, x) + x = new(Int) + x.SetString(test.y, 0) + testBitset(t, x) + } + for i, test := range bitsetTests { + x := new(Int) + x.SetString(test.x, 0) + b := x.Bit(test.i) + if b != test.b { + + t.Errorf("#%d want %v got %v", i, test.b, b) + } + } +} + + +func BenchmarkBitset(b *testing.B) { + z := new(Int) + z.SetBit(z, 512, 1) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + z.SetBit(z, i&512, 1) + } +} + + +func BenchmarkBitsetNeg(b *testing.B) { + z := NewInt(-1) + z.SetBit(z, 512, 0) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + z.SetBit(z, i&512, 0) + } +} + + +func BenchmarkBitsetOrig(b *testing.B) { + z := new(Int) + altSetBit(z, z, 512, 1) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + altSetBit(z, z, i&512, 1) + } +} + + +func BenchmarkBitsetNegOrig(b *testing.B) { + z := NewInt(-1) + altSetBit(z, z, 512, 0) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + altSetBit(z, z, i&512, 0) + } +} + + func TestBitwise(t *testing.T) { x := new(Int) y := new(Int) @@ -1019,6 +1214,7 @@ var notTests = []struct { }, } + func TestNot(t *testing.T) { in := new(Int) out := new(Int) @@ -1047,6 +1243,7 @@ var modInverseTests = []struct { {"239487239847", "2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919"}, } + func TestModInverse(t *testing.T) { var element, prime Int one := NewInt(1) diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go index 4848d427b..c2b95e8a2 100755 --- a/src/pkg/big/nat.go +++ b/src/pkg/big/nat.go @@ -20,6 +20,7 @@ package big import "rand" + // An unsigned integer x of the form // // x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] @@ -668,16 +669,24 @@ func (z nat) scan(s string, base int) (nat, int, int) { } -// string converts x to a string for a given base, with 2 <= base <= 16. -// TODO(gri) in the style of the other routines, perhaps this should take -// a []byte buffer and return it -func (x nat) string(base int) string { - if base < 2 || 16 < base { - panic("illegal base") - } +// Character sets for string conversion. +const ( + lowercaseDigits = "0123456789abcdefghijklmnopqrstuvwxyz" + uppercaseDigits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" +) + +// string converts x to a string using digits from a charset; a digit with +// value d is represented by charset[d]. The conversion base is determined +// by len(charset), which must be >= 2. +func (x nat) string(charset string) string { + base := len(charset) - if len(x) == 0 { - return "0" + // special cases + switch { + case base < 2: + panic("illegal base") + case len(x) == 0: + return string(charset[0]) } // allocate buffer for conversion @@ -692,13 +701,20 @@ func (x nat) string(base int) string { i-- var r Word q, r = q.divW(q, Word(base)) - s[i] = "0123456789abcdef"[r] + s[i] = charset[r] } return string(s[i:]) } +// decimalString returns a decimal representation of x. +// It calls x.string with the charset "0123456789". +func (x nat) decimalString() string { + return x.string(lowercaseDigits[0:10]) +} + + const deBruijn32 = 0x077CB531 var deBruijn32Lookup = []byte{ @@ -721,7 +737,7 @@ var deBruijn64Lookup = []byte{ func trailingZeroBits(x Word) int { // x & -x leaves only the right-most bit set in the word. Let k be the // index of that bit. Since only a single bit is set, the value is two - // to the power of k. Multipling by a power of two is equivalent to + // to the power of k. Multiplying by a power of two is equivalent to // left shifting, in this case by k bits. The de Bruijn constant is // such that all six bit, consecutive substrings are distinct. // Therefore, if we have a left shifted version of this constant we can @@ -773,6 +789,43 @@ func (z nat) shr(x nat, s uint) nat { } +func (z nat) setBit(x nat, i uint, b uint) nat { + j := int(i / _W) + m := Word(1) << (i % _W) + n := len(x) + switch b { + case 0: + z = z.make(n) + copy(z, x) + if j >= n { + // no need to grow + return z + } + z[j] &^= m + return z.norm() + case 1: + if j >= n { + n = j + 1 + } + z = z.make(n) + copy(z, x) + z[j] |= m + // no need to normalize + return z + } + panic("set bit is not 0 or 1") +} + + +func (z nat) bit(i uint) uint { + j := int(i / _W) + if j >= len(z) { + return 0 + } + return uint(z[j] >> (i % _W) & 1) +} + + func (z nat) and(x, y nat) nat { m := len(x) n := len(y) diff --git a/src/pkg/big/nat_test.go b/src/pkg/big/nat_test.go index 0bcb94554..a29843a3f 100755 --- a/src/pkg/big/nat_test.go +++ b/src/pkg/big/nat_test.go @@ -133,7 +133,7 @@ var mulRangesN = []struct { func TestMulRangeN(t *testing.T) { for i, r := range mulRangesN { - prod := nat(nil).mulRange(r.a, r.b).string(10) + prod := nat(nil).mulRange(r.a, r.b).decimalString() if prod != r.prod { t.Errorf("#%d: got %s; want %s", i, prod, r.prod) } @@ -167,31 +167,35 @@ func BenchmarkMul(b *testing.B) { } -var tab = []struct { - x nat - b int - s string +var strTests = []struct { + x nat // nat value to be converted + c string // conversion charset + s string // expected result }{ - {nil, 10, "0"}, - {nat{1}, 10, "1"}, - {nat{10}, 10, "10"}, - {nat{1234567890}, 10, "1234567890"}, + {nil, "01", "0"}, + {nat{1}, "01", "1"}, + {nat{0xc5}, "01", "11000101"}, + {nat{03271}, lowercaseDigits[0:8], "3271"}, + {nat{10}, lowercaseDigits[0:10], "10"}, + {nat{1234567890}, uppercaseDigits[0:10], "1234567890"}, + {nat{0xdeadbeef}, lowercaseDigits[0:16], "deadbeef"}, + {nat{0xdeadbeef}, uppercaseDigits[0:16], "DEADBEEF"}, } func TestString(t *testing.T) { - for _, a := range tab { - s := a.x.string(a.b) + for _, a := range strTests { + s := a.x.string(a.c) if s != a.s { t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) } - x, b, n := nat(nil).scan(a.s, a.b) + x, b, n := nat(nil).scan(a.s, len(a.c)) if x.cmp(a.x) != 0 { t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) } - if b != a.b { - t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.b) + if b != len(a.c) { + t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, len(a.c)) } if n != len(a.s) { t.Errorf("scan%+v\n\tgot n = %d; want %d", a, n, len(a.s)) diff --git a/src/pkg/big/rat.go b/src/pkg/big/rat.go index e70673a1c..2adf316e6 100644 --- a/src/pkg/big/rat.go +++ b/src/pkg/big/rat.go @@ -84,7 +84,7 @@ func (z *Rat) Num() *Int { } -// Demom returns the denominator of z; it is always > 0. +// Denom returns the denominator of z; it is always > 0. // The result is a reference to z's denominator; it // may change if a new value is assigned to z. func (z *Rat) Denom() *Int { @@ -270,7 +270,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { // String returns a string representation of z in the form "a/b" (even if b == 1). func (z *Rat) String() string { - return z.a.String() + "/" + z.b.string(10) + return z.a.String() + "/" + z.b.decimalString() } @@ -311,13 +311,13 @@ func (z *Rat) FloatString(prec int) string { } } - s := q.string(10) + s := q.decimalString() if z.a.neg { s = "-" + s } if prec > 0 { - rs := r.string(10) + rs := r.decimalString() leadingZeros := prec - len(rs) s += "." + strings.Repeat("0", leadingZeros) + rs } |