summaryrefslogtreecommitdiff
path: root/src/pkg/math/big
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/math/big')
-rw-r--r--src/pkg/math/big/arith_386.s18
-rw-r--r--src/pkg/math/big/arith_amd64.s331
-rw-r--r--src/pkg/math/big/arith_arm.s18
-rw-r--r--src/pkg/math/big/arith_test.go86
-rw-r--r--src/pkg/math/big/calibrate_test.go42
-rw-r--r--src/pkg/math/big/gcd_test.go47
-rw-r--r--src/pkg/math/big/int.go124
-rw-r--r--src/pkg/math/big/int_test.go205
-rw-r--r--src/pkg/math/big/nat.go309
-rw-r--r--src/pkg/math/big/nat_test.go135
-rw-r--r--src/pkg/math/big/rat.go269
-rw-r--r--src/pkg/math/big/rat_test.go499
12 files changed, 1656 insertions, 427 deletions
diff --git a/src/pkg/math/big/arith_386.s b/src/pkg/math/big/arith_386.s
index f1262c651..c62483317 100644
--- a/src/pkg/math/big/arith_386.s
+++ b/src/pkg/math/big/arith_386.s
@@ -29,7 +29,7 @@ TEXT ·addVV(SB),7,$0
MOVL z+0(FP), DI
MOVL x+12(FP), SI
MOVL y+24(FP), CX
- MOVL n+4(FP), BP
+ MOVL z+4(FP), BP
MOVL $0, BX // i = 0
MOVL $0, DX // c = 0
JMP E1
@@ -54,7 +54,7 @@ TEXT ·subVV(SB),7,$0
MOVL z+0(FP), DI
MOVL x+12(FP), SI
MOVL y+24(FP), CX
- MOVL n+4(FP), BP
+ MOVL z+4(FP), BP
MOVL $0, BX // i = 0
MOVL $0, DX // c = 0
JMP E2
@@ -78,7 +78,7 @@ TEXT ·addVW(SB),7,$0
MOVL z+0(FP), DI
MOVL x+12(FP), SI
MOVL y+24(FP), AX // c = y
- MOVL n+4(FP), BP
+ MOVL z+4(FP), BP
MOVL $0, BX // i = 0
JMP E3
@@ -100,7 +100,7 @@ TEXT ·subVW(SB),7,$0
MOVL z+0(FP), DI
MOVL x+12(FP), SI
MOVL y+24(FP), AX // c = y
- MOVL n+4(FP), BP
+ MOVL z+4(FP), BP
MOVL $0, BX // i = 0
JMP E4
@@ -120,7 +120,7 @@ E4: CMPL BX, BP // i < n
// func shlVU(z, x []Word, s uint) (c Word)
TEXT ·shlVU(SB),7,$0
- MOVL n+4(FP), BX // i = n
+ MOVL z+4(FP), BX // i = z
SUBL $1, BX // i--
JL X8b // i < 0 (n <= 0)
@@ -155,7 +155,7 @@ X8b: MOVL $0, c+28(FP)
// func shrVU(z, x []Word, s uint) (c Word)
TEXT ·shrVU(SB),7,$0
- MOVL n+4(FP), BP
+ MOVL z+4(FP), BP
SUBL $1, BP // n--
JL X9b // n < 0 (n <= 0)
@@ -196,7 +196,7 @@ TEXT ·mulAddVWW(SB),7,$0
MOVL x+12(FP), SI
MOVL y+24(FP), BP
MOVL r+28(FP), CX // c = r
- MOVL n+4(FP), BX
+ MOVL z+4(FP), BX
LEAL (DI)(BX*4), DI
LEAL (SI)(BX*4), SI
NEGL BX // i = -n
@@ -222,7 +222,7 @@ TEXT ·addMulVVW(SB),7,$0
MOVL z+0(FP), DI
MOVL x+12(FP), SI
MOVL y+24(FP), BP
- MOVL n+4(FP), BX
+ MOVL z+4(FP), BX
LEAL (DI)(BX*4), DI
LEAL (SI)(BX*4), SI
NEGL BX // i = -n
@@ -251,7 +251,7 @@ TEXT ·divWVW(SB),7,$0
MOVL xn+12(FP), DX // r = xn
MOVL x+16(FP), SI
MOVL y+28(FP), CX
- MOVL n+4(FP), BX // i = n
+ MOVL z+4(FP), BX // i = z
JMP E7
L7: MOVL (SI)(BX*4), AX
diff --git a/src/pkg/math/big/arith_amd64.s b/src/pkg/math/big/arith_amd64.s
index 54f647322..d85964502 100644
--- a/src/pkg/math/big/arith_amd64.s
+++ b/src/pkg/math/big/arith_amd64.s
@@ -5,7 +5,15 @@
// This file provides fast assembly versions for the elementary
// arithmetic operations on vectors implemented in arith.go.
-// TODO(gri) - experiment with unrolled loops for faster execution
+// Literal instruction for MOVQ $0, CX.
+// (MOVQ $0, reg is translated to XORQ reg, reg and clears CF.)
+#define ZERO_CX BYTE $0x48; \
+ BYTE $0xc7; \
+ BYTE $0xc1; \
+ BYTE $0x00; \
+ BYTE $0x00; \
+ BYTE $0x00; \
+ BYTE $0x00
// func mulWW(x, y Word) (z1, z0 Word)
TEXT ·mulWW(SB),7,$0
@@ -28,114 +36,231 @@ TEXT ·divWW(SB),7,$0
// func addVV(z, x, y []Word) (c Word)
TEXT ·addVV(SB),7,$0
+ MOVQ z+8(FP), DI
+ MOVQ x+24(FP), R8
+ MOVQ y+48(FP), R9
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVQ y+32(FP), R9
- MOVL n+8(FP), R11
- MOVQ $0, BX // i = 0
- MOVQ $0, DX // c = 0
- JMP E1
-
-L1: MOVQ (R8)(BX*8), AX
- RCRQ $1, DX
- ADCQ (R9)(BX*8), AX
- RCLQ $1, DX
- MOVQ AX, (R10)(BX*8)
- ADDL $1, BX // i++
-E1: CMPQ BX, R11 // i < n
- JL L1
-
- MOVQ DX, c+48(FP)
+ MOVQ $0, CX // c = 0
+ MOVQ $0, SI // i = 0
+
+ // s/JL/JMP/ below to disable the unrolled loop
+ SUBQ $4, DI // n -= 4
+ JL V1 // if n < 0 goto V1
+
+U1: // n >= 0
+ // regular loop body unrolled 4x
+ RCRQ $1, CX // CF = c
+ MOVQ 0(R8)(SI*8), R11
+ MOVQ 8(R8)(SI*8), R12
+ MOVQ 16(R8)(SI*8), R13
+ MOVQ 24(R8)(SI*8), R14
+ ADCQ 0(R9)(SI*8), R11
+ ADCQ 8(R9)(SI*8), R12
+ ADCQ 16(R9)(SI*8), R13
+ ADCQ 24(R9)(SI*8), R14
+ MOVQ R11, 0(R10)(SI*8)
+ MOVQ R12, 8(R10)(SI*8)
+ MOVQ R13, 16(R10)(SI*8)
+ MOVQ R14, 24(R10)(SI*8)
+ RCLQ $1, CX // c = CF
+
+ ADDQ $4, SI // i += 4
+ SUBQ $4, DI // n -= 4
+ JGE U1 // if n >= 0 goto U1
+
+V1: ADDQ $4, DI // n += 4
+ JLE E1 // if n <= 0 goto E1
+
+L1: // n > 0
+ RCRQ $1, CX // CF = c
+ MOVQ 0(R8)(SI*8), R11
+ ADCQ 0(R9)(SI*8), R11
+ MOVQ R11, 0(R10)(SI*8)
+ RCLQ $1, CX // c = CF
+
+ ADDQ $1, SI // i++
+ SUBQ $1, DI // n--
+ JG L1 // if n > 0 goto L1
+
+E1: MOVQ CX, c+72(FP) // return c
RET
// func subVV(z, x, y []Word) (c Word)
-// (same as addVV_s except for SBBQ instead of ADCQ and label names)
+// (same as addVV except for SBBQ instead of ADCQ and label names)
TEXT ·subVV(SB),7,$0
+ MOVQ z+8(FP), DI
+ MOVQ x+24(FP), R8
+ MOVQ y+48(FP), R9
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVQ y+32(FP), R9
- MOVL n+8(FP), R11
- MOVQ $0, BX // i = 0
- MOVQ $0, DX // c = 0
- JMP E2
-
-L2: MOVQ (R8)(BX*8), AX
- RCRQ $1, DX
- SBBQ (R9)(BX*8), AX
- RCLQ $1, DX
- MOVQ AX, (R10)(BX*8)
- ADDL $1, BX // i++
-
-E2: CMPQ BX, R11 // i < n
- JL L2
- MOVQ DX, c+48(FP)
+ MOVQ $0, CX // c = 0
+ MOVQ $0, SI // i = 0
+
+ // s/JL/JMP/ below to disable the unrolled loop
+ SUBQ $4, DI // n -= 4
+ JL V2 // if n < 0 goto V2
+
+U2: // n >= 0
+ // regular loop body unrolled 4x
+ RCRQ $1, CX // CF = c
+ MOVQ 0(R8)(SI*8), R11
+ MOVQ 8(R8)(SI*8), R12
+ MOVQ 16(R8)(SI*8), R13
+ MOVQ 24(R8)(SI*8), R14
+ SBBQ 0(R9)(SI*8), R11
+ SBBQ 8(R9)(SI*8), R12
+ SBBQ 16(R9)(SI*8), R13
+ SBBQ 24(R9)(SI*8), R14
+ MOVQ R11, 0(R10)(SI*8)
+ MOVQ R12, 8(R10)(SI*8)
+ MOVQ R13, 16(R10)(SI*8)
+ MOVQ R14, 24(R10)(SI*8)
+ RCLQ $1, CX // c = CF
+
+ ADDQ $4, SI // i += 4
+ SUBQ $4, DI // n -= 4
+ JGE U2 // if n >= 0 goto U2
+
+V2: ADDQ $4, DI // n += 4
+ JLE E2 // if n <= 0 goto E2
+
+L2: // n > 0
+ RCRQ $1, CX // CF = c
+ MOVQ 0(R8)(SI*8), R11
+ SBBQ 0(R9)(SI*8), R11
+ MOVQ R11, 0(R10)(SI*8)
+ RCLQ $1, CX // c = CF
+
+ ADDQ $1, SI // i++
+ SUBQ $1, DI // n--
+ JG L2 // if n > 0 goto L2
+
+E2: MOVQ CX, c+72(FP) // return c
RET
// func addVW(z, x []Word, y Word) (c Word)
TEXT ·addVW(SB),7,$0
+ MOVQ z+8(FP), DI
+ MOVQ x+24(FP), R8
+ MOVQ y+48(FP), CX // c = y
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVQ y+32(FP), AX // c = y
- MOVL n+8(FP), R11
- MOVQ $0, BX // i = 0
- JMP E3
-L3: ADDQ (R8)(BX*8), AX
- MOVQ AX, (R10)(BX*8)
- RCLQ $1, AX
- ANDQ $1, AX
- ADDL $1, BX // i++
-
-E3: CMPQ BX, R11 // i < n
- JL L3
-
- MOVQ AX, c+40(FP)
+ MOVQ $0, SI // i = 0
+
+ // s/JL/JMP/ below to disable the unrolled loop
+ SUBQ $4, DI // n -= 4
+ JL V3 // if n < 4 goto V3
+
+U3: // n >= 0
+ // regular loop body unrolled 4x
+ MOVQ 0(R8)(SI*8), R11
+ MOVQ 8(R8)(SI*8), R12
+ MOVQ 16(R8)(SI*8), R13
+ MOVQ 24(R8)(SI*8), R14
+ ADDQ CX, R11
+ ZERO_CX
+ ADCQ $0, R12
+ ADCQ $0, R13
+ ADCQ $0, R14
+ SETCS CX // c = CF
+ MOVQ R11, 0(R10)(SI*8)
+ MOVQ R12, 8(R10)(SI*8)
+ MOVQ R13, 16(R10)(SI*8)
+ MOVQ R14, 24(R10)(SI*8)
+
+ ADDQ $4, SI // i += 4
+ SUBQ $4, DI // n -= 4
+ JGE U3 // if n >= 0 goto U3
+
+V3: ADDQ $4, DI // n += 4
+ JLE E3 // if n <= 0 goto E3
+
+L3: // n > 0
+ ADDQ 0(R8)(SI*8), CX
+ MOVQ CX, 0(R10)(SI*8)
+ ZERO_CX
+ RCLQ $1, CX // c = CF
+
+ ADDQ $1, SI // i++
+ SUBQ $1, DI // n--
+ JG L3 // if n > 0 goto L3
+
+E3: MOVQ CX, c+56(FP) // return c
RET
// func subVW(z, x []Word, y Word) (c Word)
+// (same as addVW except for SUBQ/SBBQ instead of ADDQ/ADCQ and label names)
TEXT ·subVW(SB),7,$0
+ MOVQ z+8(FP), DI
+ MOVQ x+24(FP), R8
+ MOVQ y+48(FP), CX // c = y
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVQ y+32(FP), AX // c = y
- MOVL n+8(FP), R11
- MOVQ $0, BX // i = 0
- JMP E4
-
-L4: MOVQ (R8)(BX*8), DX // TODO(gri) is there a reverse SUBQ?
- SUBQ AX, DX
- MOVQ DX, (R10)(BX*8)
- RCLQ $1, AX
- ANDQ $1, AX
- ADDL $1, BX // i++
-
-E4: CMPQ BX, R11 // i < n
- JL L4
-
- MOVQ AX, c+40(FP)
+
+ MOVQ $0, SI // i = 0
+
+ // s/JL/JMP/ below to disable the unrolled loop
+ SUBQ $4, DI // n -= 4
+ JL V4 // if n < 4 goto V4
+
+U4: // n >= 0
+ // regular loop body unrolled 4x
+ MOVQ 0(R8)(SI*8), R11
+ MOVQ 8(R8)(SI*8), R12
+ MOVQ 16(R8)(SI*8), R13
+ MOVQ 24(R8)(SI*8), R14
+ SUBQ CX, R11
+ ZERO_CX
+ SBBQ $0, R12
+ SBBQ $0, R13
+ SBBQ $0, R14
+ SETCS CX // c = CF
+ MOVQ R11, 0(R10)(SI*8)
+ MOVQ R12, 8(R10)(SI*8)
+ MOVQ R13, 16(R10)(SI*8)
+ MOVQ R14, 24(R10)(SI*8)
+
+ ADDQ $4, SI // i += 4
+ SUBQ $4, DI // n -= 4
+ JGE U4 // if n >= 0 goto U4
+
+V4: ADDQ $4, DI // n += 4
+ JLE E4 // if n <= 0 goto E4
+
+L4: // n > 0
+ MOVQ 0(R8)(SI*8), R11
+ SUBQ CX, R11
+ MOVQ R11, 0(R10)(SI*8)
+ ZERO_CX
+ RCLQ $1, CX // c = CF
+
+ ADDQ $1, SI // i++
+ SUBQ $1, DI // n--
+ JG L4 // if n > 0 goto L4
+
+E4: MOVQ CX, c+56(FP) // return c
RET
// func shlVU(z, x []Word, s uint) (c Word)
TEXT ·shlVU(SB),7,$0
- MOVL n+8(FP), BX // i = n
- SUBL $1, BX // i--
+ MOVQ z+8(FP), BX // i = z
+ SUBQ $1, BX // i--
JL X8b // i < 0 (n <= 0)
// n > 0
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVL s+32(FP), CX
+ MOVQ x+24(FP), R8
+ MOVQ s+48(FP), CX
MOVQ (R8)(BX*8), AX // w1 = x[n-1]
MOVQ $0, DX
SHLQ CX, DX:AX // w1>>ŝ
- MOVQ DX, c+40(FP)
+ MOVQ DX, c+56(FP)
- CMPL BX, $0
+ CMPQ BX, $0
JLE X8a // i <= 0
// i > 0
@@ -143,7 +268,7 @@ L8: MOVQ AX, DX // w = w1
MOVQ -8(R8)(BX*8), AX // w1 = x[i-1]
SHLQ CX, DX:AX // w<<s | w1>>ŝ
MOVQ DX, (R10)(BX*8) // z[i] = w<<s | w1>>ŝ
- SUBL $1, BX // i--
+ SUBQ $1, BX // i--
JG L8 // i > 0
// i <= 0
@@ -151,24 +276,24 @@ X8a: SHLQ CX, AX // w1<<s
MOVQ AX, (R10) // z[0] = w1<<s
RET
-X8b: MOVQ $0, c+40(FP)
+X8b: MOVQ $0, c+56(FP)
RET
// func shrVU(z, x []Word, s uint) (c Word)
TEXT ·shrVU(SB),7,$0
- MOVL n+8(FP), R11
- SUBL $1, R11 // n--
+ MOVQ z+8(FP), R11
+ SUBQ $1, R11 // n--
JL X9b // n < 0 (n <= 0)
// n > 0
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVL s+32(FP), CX
+ MOVQ x+24(FP), R8
+ MOVQ s+48(FP), CX
MOVQ (R8), AX // w1 = x[0]
MOVQ $0, DX
SHRQ CX, DX:AX // w1<<ŝ
- MOVQ DX, c+40(FP)
+ MOVQ DX, c+56(FP)
MOVQ $0, BX // i = 0
JMP E9
@@ -178,7 +303,7 @@ L9: MOVQ AX, DX // w = w1
MOVQ 8(R8)(BX*8), AX // w1 = x[i+1]
SHRQ CX, DX:AX // w>>s | w1<<ŝ
MOVQ DX, (R10)(BX*8) // z[i] = w>>s | w1<<ŝ
- ADDL $1, BX // i++
+ ADDQ $1, BX // i++
E9: CMPQ BX, R11
JL L9 // i < n-1
@@ -188,17 +313,17 @@ X9a: SHRQ CX, AX // w1>>s
MOVQ AX, (R10)(R11*8) // z[n-1] = w1>>s
RET
-X9b: MOVQ $0, c+40(FP)
+X9b: MOVQ $0, c+56(FP)
RET
// func mulAddVWW(z, x []Word, y, r Word) (c Word)
TEXT ·mulAddVWW(SB),7,$0
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVQ y+32(FP), R9
- MOVQ r+40(FP), CX // c = r
- MOVL n+8(FP), R11
+ MOVQ x+24(FP), R8
+ MOVQ y+48(FP), R9
+ MOVQ r+56(FP), CX // c = r
+ MOVQ z+8(FP), R11
MOVQ $0, BX // i = 0
JMP E5
@@ -208,21 +333,21 @@ L5: MOVQ (R8)(BX*8), AX
ADCQ $0, DX
MOVQ AX, (R10)(BX*8)
MOVQ DX, CX
- ADDL $1, BX // i++
+ ADDQ $1, BX // i++
E5: CMPQ BX, R11 // i < n
JL L5
- MOVQ CX, c+48(FP)
+ MOVQ CX, c+64(FP)
RET
// func addMulVVW(z, x []Word, y Word) (c Word)
TEXT ·addMulVVW(SB),7,$0
MOVQ z+0(FP), R10
- MOVQ x+16(FP), R8
- MOVQ y+32(FP), R9
- MOVL n+8(FP), R11
+ MOVQ x+24(FP), R8
+ MOVQ y+48(FP), R9
+ MOVQ z+8(FP), R11
MOVQ $0, BX // i = 0
MOVQ $0, CX // c = 0
JMP E6
@@ -234,41 +359,41 @@ L6: MOVQ (R8)(BX*8), AX
ADDQ AX, (R10)(BX*8)
ADCQ $0, DX
MOVQ DX, CX
- ADDL $1, BX // i++
+ ADDQ $1, BX // i++
E6: CMPQ BX, R11 // i < n
JL L6
- MOVQ CX, c+40(FP)
+ MOVQ CX, c+56(FP)
RET
// func divWVW(z []Word, xn Word, x []Word, y Word) (r Word)
TEXT ·divWVW(SB),7,$0
MOVQ z+0(FP), R10
- MOVQ xn+16(FP), DX // r = xn
- MOVQ x+24(FP), R8
- MOVQ y+40(FP), R9
- MOVL n+8(FP), BX // i = n
+ MOVQ xn+24(FP), DX // r = xn
+ MOVQ x+32(FP), R8
+ MOVQ y+56(FP), R9
+ MOVQ z+8(FP), BX // i = z
JMP E7
L7: MOVQ (R8)(BX*8), AX
DIVQ R9
MOVQ AX, (R10)(BX*8)
-E7: SUBL $1, BX // i--
+E7: SUBQ $1, BX // i--
JGE L7 // i >= 0
- MOVQ DX, r+48(FP)
+ MOVQ DX, r+64(FP)
RET
// func bitLen(x Word) (n int)
TEXT ·bitLen(SB),7,$0
BSRQ x+0(FP), AX
JZ Z1
- INCL AX
- MOVL AX, n+8(FP)
+ ADDQ $1, AX
+ MOVQ AX, n+8(FP)
RET
-Z1: MOVL $0, n+8(FP)
+Z1: MOVQ $0, n+8(FP)
RET
diff --git a/src/pkg/math/big/arith_arm.s b/src/pkg/math/big/arith_arm.s
index dbf3360b5..64610f915 100644
--- a/src/pkg/math/big/arith_arm.s
+++ b/src/pkg/math/big/arith_arm.s
@@ -13,7 +13,7 @@ TEXT ·addVV(SB),7,$0
MOVW z+0(FP), R1
MOVW x+12(FP), R2
MOVW y+24(FP), R3
- MOVW n+4(FP), R4
+ MOVW z+4(FP), R4
MOVW R4<<2, R4
ADD R1, R4
B E1
@@ -41,7 +41,7 @@ TEXT ·subVV(SB),7,$0
MOVW z+0(FP), R1
MOVW x+12(FP), R2
MOVW y+24(FP), R3
- MOVW n+4(FP), R4
+ MOVW z+4(FP), R4
MOVW R4<<2, R4
ADD R1, R4
B E2
@@ -68,7 +68,7 @@ TEXT ·addVW(SB),7,$0
MOVW z+0(FP), R1
MOVW x+12(FP), R2
MOVW y+24(FP), R3
- MOVW n+4(FP), R4
+ MOVW z+4(FP), R4
MOVW R4<<2, R4
ADD R1, R4
CMP R1, R4
@@ -102,7 +102,7 @@ TEXT ·subVW(SB),7,$0
MOVW z+0(FP), R1
MOVW x+12(FP), R2
MOVW y+24(FP), R3
- MOVW n+4(FP), R4
+ MOVW z+4(FP), R4
MOVW R4<<2, R4
ADD R1, R4
CMP R1, R4
@@ -134,7 +134,7 @@ E4:
// func shlVU(z, x []Word, s uint) (c Word)
TEXT ·shlVU(SB),7,$0
- MOVW n+4(FP), R5
+ MOVW z+4(FP), R5
CMP $0, R5
BEQ X7
@@ -183,7 +183,7 @@ X7:
// func shrVU(z, x []Word, s uint) (c Word)
TEXT ·shrVU(SB),7,$0
- MOVW n+4(FP), R5
+ MOVW z+4(FP), R5
CMP $0, R5
BEQ X6
@@ -238,7 +238,7 @@ TEXT ·mulAddVWW(SB),7,$0
MOVW x+12(FP), R2
MOVW y+24(FP), R3
MOVW r+28(FP), R4
- MOVW n+4(FP), R5
+ MOVW z+4(FP), R5
MOVW R5<<2, R5
ADD R1, R5
B E8
@@ -265,7 +265,7 @@ TEXT ·addMulVVW(SB),7,$0
MOVW z+0(FP), R1
MOVW x+12(FP), R2
MOVW y+24(FP), R3
- MOVW n+4(FP), R5
+ MOVW z+4(FP), R5
MOVW R5<<2, R5
ADD R1, R5
MOVW $0, R4
@@ -314,7 +314,7 @@ TEXT ·mulWW(SB),7,$0
// func bitLen(x Word) (n int)
TEXT ·bitLen(SB),7,$0
MOVW x+0(FP), R0
- WORD $0xe16f0f10 // CLZ R0, R0 (count leading zeros)
+ CLZ R0, R0
MOVW $32, R1
SUB.S R0, R1
MOVW R1, n+4(FP)
diff --git a/src/pkg/math/big/arith_test.go b/src/pkg/math/big/arith_test.go
index c7e3d284c..3615a659c 100644
--- a/src/pkg/math/big/arith_test.go
+++ b/src/pkg/math/big/arith_test.go
@@ -4,7 +4,10 @@
package big
-import "testing"
+import (
+ "math/rand"
+ "testing"
+)
type funWW func(x, y, c Word) (z1, z0 Word)
type argWW struct {
@@ -100,6 +103,43 @@ func TestFunVV(t *testing.T) {
}
}
+// Always the same seed for reproducible results.
+var rnd = rand.New(rand.NewSource(0))
+
+func rndW() Word {
+ return Word(rnd.Int63()<<1 | rnd.Int63n(2))
+}
+
+func rndV(n int) []Word {
+ v := make([]Word, n)
+ for i := range v {
+ v[i] = rndW()
+ }
+ return v
+}
+
+func benchmarkFunVV(b *testing.B, f funVV, n int) {
+ x := rndV(n)
+ y := rndV(n)
+ z := make([]Word, n)
+ b.SetBytes(int64(n * _W))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ f(z, x, y)
+ }
+}
+
+func BenchmarkAddVV_1(b *testing.B) { benchmarkFunVV(b, addVV, 1) }
+func BenchmarkAddVV_2(b *testing.B) { benchmarkFunVV(b, addVV, 2) }
+func BenchmarkAddVV_3(b *testing.B) { benchmarkFunVV(b, addVV, 3) }
+func BenchmarkAddVV_4(b *testing.B) { benchmarkFunVV(b, addVV, 4) }
+func BenchmarkAddVV_5(b *testing.B) { benchmarkFunVV(b, addVV, 5) }
+func BenchmarkAddVV_1e1(b *testing.B) { benchmarkFunVV(b, addVV, 1e1) }
+func BenchmarkAddVV_1e2(b *testing.B) { benchmarkFunVV(b, addVV, 1e2) }
+func BenchmarkAddVV_1e3(b *testing.B) { benchmarkFunVV(b, addVV, 1e3) }
+func BenchmarkAddVV_1e4(b *testing.B) { benchmarkFunVV(b, addVV, 1e4) }
+func BenchmarkAddVV_1e5(b *testing.B) { benchmarkFunVV(b, addVV, 1e5) }
+
type funVW func(z, x []Word, y Word) (c Word)
type argVW struct {
z, x nat
@@ -210,6 +250,28 @@ func TestFunVW(t *testing.T) {
}
}
+func benchmarkFunVW(b *testing.B, f funVW, n int) {
+ x := rndV(n)
+ y := rndW()
+ z := make([]Word, n)
+ b.SetBytes(int64(n * _W))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ f(z, x, y)
+ }
+}
+
+func BenchmarkAddVW_1(b *testing.B) { benchmarkFunVW(b, addVW, 1) }
+func BenchmarkAddVW_2(b *testing.B) { benchmarkFunVW(b, addVW, 2) }
+func BenchmarkAddVW_3(b *testing.B) { benchmarkFunVW(b, addVW, 3) }
+func BenchmarkAddVW_4(b *testing.B) { benchmarkFunVW(b, addVW, 4) }
+func BenchmarkAddVW_5(b *testing.B) { benchmarkFunVW(b, addVW, 5) }
+func BenchmarkAddVW_1e1(b *testing.B) { benchmarkFunVW(b, addVW, 1e1) }
+func BenchmarkAddVW_1e2(b *testing.B) { benchmarkFunVW(b, addVW, 1e2) }
+func BenchmarkAddVW_1e3(b *testing.B) { benchmarkFunVW(b, addVW, 1e3) }
+func BenchmarkAddVW_1e4(b *testing.B) { benchmarkFunVW(b, addVW, 1e4) }
+func BenchmarkAddVW_1e5(b *testing.B) { benchmarkFunVW(b, addVW, 1e5) }
+
type funVWW func(z, x []Word, y, r Word) (c Word)
type argVWW struct {
z, x nat
@@ -334,6 +396,28 @@ func TestMulAddWWW(t *testing.T) {
}
}
+func benchmarkAddMulVVW(b *testing.B, n int) {
+ x := rndV(n)
+ y := rndW()
+ z := make([]Word, n)
+ b.SetBytes(int64(n * _W))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ addMulVVW(z, x, y)
+ }
+}
+
+func BenchmarkAddMulVVW_1(b *testing.B) { benchmarkAddMulVVW(b, 1) }
+func BenchmarkAddMulVVW_2(b *testing.B) { benchmarkAddMulVVW(b, 2) }
+func BenchmarkAddMulVVW_3(b *testing.B) { benchmarkAddMulVVW(b, 3) }
+func BenchmarkAddMulVVW_4(b *testing.B) { benchmarkAddMulVVW(b, 4) }
+func BenchmarkAddMulVVW_5(b *testing.B) { benchmarkAddMulVVW(b, 5) }
+func BenchmarkAddMulVVW_1e1(b *testing.B) { benchmarkAddMulVVW(b, 1e1) }
+func BenchmarkAddMulVVW_1e2(b *testing.B) { benchmarkAddMulVVW(b, 1e2) }
+func BenchmarkAddMulVVW_1e3(b *testing.B) { benchmarkAddMulVVW(b, 1e3) }
+func BenchmarkAddMulVVW_1e4(b *testing.B) { benchmarkAddMulVVW(b, 1e4) }
+func BenchmarkAddMulVVW_1e5(b *testing.B) { benchmarkAddMulVVW(b, 1e5) }
+
func testWordBitLen(t *testing.T, fname string, f func(Word) int) {
for i := 0; i <= _W; i++ {
x := Word(1) << uint(i-1) // i == 0 => x == 0
diff --git a/src/pkg/math/big/calibrate_test.go b/src/pkg/math/big/calibrate_test.go
index efe1837bb..f69ffbf5c 100644
--- a/src/pkg/math/big/calibrate_test.go
+++ b/src/pkg/math/big/calibrate_test.go
@@ -21,15 +21,17 @@ import (
var calibrate = flag.Bool("calibrate", false, "run calibration test")
-// measure returns the time to run f
-func measure(f func()) time.Duration {
- const N = 100
- start := time.Now()
- for i := N; i > 0; i-- {
- f()
- }
- stop := time.Now()
- return stop.Sub(start) / N
+func karatsubaLoad(b *testing.B) {
+ BenchmarkMul(b)
+}
+
+// measureKaratsuba returns the time to run a Karatsuba-relevant benchmark
+// given Karatsuba threshold th.
+func measureKaratsuba(th int) time.Duration {
+ th, karatsubaThreshold = karatsubaThreshold, th
+ res := testing.Benchmark(karatsubaLoad)
+ karatsubaThreshold = th
+ return time.Duration(res.NsPerOp())
}
func computeThresholds() {
@@ -37,35 +39,33 @@ func computeThresholds() {
fmt.Printf("(run repeatedly for good results)\n")
// determine Tk, the work load execution time using basic multiplication
- karatsubaThreshold = 1e9 // disable karatsuba
- Tb := measure(benchmarkMulLoad)
- fmt.Printf("Tb = %dns\n", Tb)
+ Tb := measureKaratsuba(1e9) // th == 1e9 => Karatsuba multiplication disabled
+ fmt.Printf("Tb = %10s\n", Tb)
// thresholds
- n := 8 // any lower values for the threshold lead to very slow multiplies
+ th := 4
th1 := -1
th2 := -1
var deltaOld time.Duration
- for count := -1; count != 0; count-- {
+ for count := -1; count != 0 && th < 128; count-- {
// determine Tk, the work load execution time using Karatsuba multiplication
- karatsubaThreshold = n // enable karatsuba
- Tk := measure(benchmarkMulLoad)
+ Tk := measureKaratsuba(th)
// improvement over Tb
delta := (Tb - Tk) * 100 / Tb
- fmt.Printf("n = %3d Tk = %8dns %4d%%", n, Tk, delta)
+ fmt.Printf("th = %3d Tk = %10s %4d%%", th, Tk, delta)
// determine break-even point
if Tk < Tb && th1 < 0 {
- th1 = n
+ th1 = th
fmt.Print(" break-even point")
}
// determine diminishing return
if 0 < delta && delta < deltaOld && th2 < 0 {
- th2 = n
+ th2 = th
fmt.Print(" diminishing return")
}
deltaOld = delta
@@ -74,10 +74,10 @@ func computeThresholds() {
// trigger counter
if th1 >= 0 && th2 >= 0 && count < 0 {
- count = 20 // this many extra measurements after we got both thresholds
+ count = 10 // this many extra measurements after we got both thresholds
}
- n++
+ th++
}
}
diff --git a/src/pkg/math/big/gcd_test.go b/src/pkg/math/big/gcd_test.go
new file mode 100644
index 000000000..c0b9f5830
--- /dev/null
+++ b/src/pkg/math/big/gcd_test.go
@@ -0,0 +1,47 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements a GCD benchmark.
+// Usage: go test math/big -test.bench GCD
+
+package big
+
+import (
+ "math/rand"
+ "testing"
+)
+
+// randInt returns a pseudo-random Int in the range [1<<(size-1), (1<<size) - 1]
+func randInt(r *rand.Rand, size uint) *Int {
+ n := new(Int).Lsh(intOne, size-1)
+ x := new(Int).Rand(r, n)
+ return x.Add(x, n) // make sure result > 1<<(size-1)
+}
+
+func runGCD(b *testing.B, aSize, bSize uint) {
+ b.StopTimer()
+ var r = rand.New(rand.NewSource(1234))
+ aa := randInt(r, aSize)
+ bb := randInt(r, bSize)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ new(Int).GCD(nil, nil, aa, bb)
+ }
+}
+
+func BenchmarkGCD10x10(b *testing.B) { runGCD(b, 10, 10) }
+func BenchmarkGCD10x100(b *testing.B) { runGCD(b, 10, 100) }
+func BenchmarkGCD10x1000(b *testing.B) { runGCD(b, 10, 1000) }
+func BenchmarkGCD10x10000(b *testing.B) { runGCD(b, 10, 10000) }
+func BenchmarkGCD10x100000(b *testing.B) { runGCD(b, 10, 100000) }
+func BenchmarkGCD100x100(b *testing.B) { runGCD(b, 100, 100) }
+func BenchmarkGCD100x1000(b *testing.B) { runGCD(b, 100, 1000) }
+func BenchmarkGCD100x10000(b *testing.B) { runGCD(b, 100, 10000) }
+func BenchmarkGCD100x100000(b *testing.B) { runGCD(b, 100, 100000) }
+func BenchmarkGCD1000x1000(b *testing.B) { runGCD(b, 1000, 1000) }
+func BenchmarkGCD1000x10000(b *testing.B) { runGCD(b, 1000, 10000) }
+func BenchmarkGCD1000x100000(b *testing.B) { runGCD(b, 1000, 100000) }
+func BenchmarkGCD10000x10000(b *testing.B) { runGCD(b, 10000, 10000) }
+func BenchmarkGCD10000x100000(b *testing.B) { runGCD(b, 10000, 100000) }
+func BenchmarkGCD100000x100000(b *testing.B) { runGCD(b, 100000, 100000) }
diff --git a/src/pkg/math/big/int.go b/src/pkg/math/big/int.go
index cd2cd0e2d..bf2fd2009 100644
--- a/src/pkg/math/big/int.go
+++ b/src/pkg/math/big/int.go
@@ -51,6 +51,13 @@ func (z *Int) SetInt64(x int64) *Int {
return z
}
+// SetUint64 sets z to x and returns z.
+func (z *Int) SetUint64(x uint64) *Int {
+ z.abs = z.abs.setUint64(uint64(x))
+ z.neg = false
+ return z
+}
+
// NewInt allocates and returns a new Int set to x.
func NewInt(x int64) *Int {
return new(Int).SetInt64(x)
@@ -412,7 +419,7 @@ func (x *Int) Format(s fmt.State, ch rune) {
if precisionSet {
switch {
case len(digits) < precision:
- zeroes = precision - len(digits) // count of zero padding
+ zeroes = precision - len(digits) // count of zero padding
case digits == "0" && precision == 0:
return // print nothing if zero value (x == 0) and zero precision ("." or ".0")
}
@@ -519,6 +526,19 @@ func (x *Int) Int64() int64 {
return v
}
+// Uint64 returns the uint64 representation of x.
+// If x cannot be represented in an uint64, the result is undefined.
+func (x *Int) Uint64() uint64 {
+ if len(x.abs) == 0 {
+ return 0
+ }
+ v := uint64(x.abs[0])
+ if _W == 32 && len(x.abs) > 1 {
+ v |= uint64(x.abs[1]) << 32
+ }
+ return v
+}
+
// SetString sets z to the value of s, interpreted in the given base,
// and returns z and a boolean indicating success. If SetString fails,
// the value of z is undefined but the returned value is nil.
@@ -561,19 +581,18 @@ func (x *Int) BitLen() int {
return x.abs.bitLen()
}
-// Exp sets z = x**y mod m and returns z. If m is nil, z = x**y.
+// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z.
+// If y <= 0, the result is 1; if m == nil or m == 0, z = x**y.
// See Knuth, volume 2, section 4.6.3.
func (z *Int) Exp(x, y, m *Int) *Int {
if y.neg || len(y.abs) == 0 {
- neg := x.neg
- z.SetInt64(1)
- z.neg = neg
- return z
+ return z.SetInt64(1)
}
+ // y > 0
var mWords nat
if m != nil {
- mWords = m.abs
+ mWords = m.abs // m.abs may be nil for m == 0
}
z.abs = z.abs.expNN(x.abs, y.abs, mWords)
@@ -581,12 +600,12 @@ func (z *Int) Exp(x, y, m *Int) *Int {
return z
}
-// GCD sets z to the greatest common divisor of a and b, which must be
-// positive numbers, and returns z.
+// GCD sets z to the greatest common divisor of a and b, which both must
+// be > 0, and returns z.
// If x and y are not nil, GCD sets x and y such that z = a*x + b*y.
-// If either a or b is not positive, GCD sets z = x = y = 0.
+// If either a or b is <= 0, GCD sets z = x = y = 0.
func (z *Int) GCD(x, y, a, b *Int) *Int {
- if a.neg || b.neg {
+ if a.Sign() <= 0 || b.Sign() <= 0 {
z.SetInt64(0)
if x != nil {
x.SetInt64(0)
@@ -596,6 +615,9 @@ func (z *Int) GCD(x, y, a, b *Int) *Int {
}
return z
}
+ if x == nil && y == nil {
+ return z.binaryGCD(a, b)
+ }
A := new(Int).Set(a)
B := new(Int).Set(b)
@@ -640,6 +662,63 @@ func (z *Int) GCD(x, y, a, b *Int) *Int {
return z
}
+// binaryGCD sets z to the greatest common divisor of a and b, which both must
+// be > 0, and returns z.
+// See Knuth, The Art of Computer Programming, Vol. 2, Section 4.5.2, Algorithm B.
+func (z *Int) binaryGCD(a, b *Int) *Int {
+ u := z
+ v := new(Int)
+
+ // use one Euclidean iteration to ensure that u and v are approx. the same size
+ switch {
+ case len(a.abs) > len(b.abs):
+ u.Set(b)
+ v.Rem(a, b)
+ case len(a.abs) < len(b.abs):
+ u.Set(a)
+ v.Rem(b, a)
+ default:
+ u.Set(a)
+ v.Set(b)
+ }
+
+ // v might be 0 now
+ if len(v.abs) == 0 {
+ return u
+ }
+ // u > 0 && v > 0
+
+ // determine largest k such that u = u' << k, v = v' << k
+ k := u.abs.trailingZeroBits()
+ if vk := v.abs.trailingZeroBits(); vk < k {
+ k = vk
+ }
+ u.Rsh(u, k)
+ v.Rsh(v, k)
+
+ // determine t (we know that u > 0)
+ t := new(Int)
+ if u.abs[0]&1 != 0 {
+ // u is odd
+ t.Neg(v)
+ } else {
+ t.Set(u)
+ }
+
+ for len(t.abs) > 0 {
+ // reduce t
+ t.Rsh(t, t.abs.trailingZeroBits())
+ if t.neg {
+ v.Neg(t)
+ } else {
+ u.Set(t)
+ }
+ t.Sub(u, v)
+ }
+
+ return u.Lsh(u, k)
+}
+
// ProbablyPrime performs n Miller-Rabin tests to check whether x is prime.
// If it returns true, x is prime with probability 1 - 1/4^n.
// If it returns false, x is not prime.
@@ -697,6 +776,13 @@ func (z *Int) Rsh(x *Int, n uint) *Int {
// Bit returns the value of the i'th bit of x. That is, it
// returns (x>>i)&1. The bit index i must be >= 0.
func (x *Int) Bit(i int) uint {
+ if i == 0 {
+ // optimization for common case: odd/even test of x
+ if len(x.abs) > 0 {
+ return uint(x.abs[0] & 1) // bit 0 is same for -x
+ }
+ return 0
+ }
if i < 0 {
panic("negative bit index")
}
@@ -894,3 +980,19 @@ func (z *Int) GobDecode(buf []byte) error {
z.abs = z.abs.setBytes(buf[1:])
return nil
}
+
+// MarshalJSON implements the json.Marshaler interface.
+func (x *Int) MarshalJSON() ([]byte, error) {
+ // TODO(gri): get rid of the []byte/string conversions
+ return []byte(x.String()), nil
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface.
+func (z *Int) UnmarshalJSON(x []byte) error {
+ // TODO(gri): get rid of the []byte/string conversions
+ _, ok := z.SetString(string(x), 0)
+ if !ok {
+ return fmt.Errorf("math/big: cannot unmarshal %s into a *big.Int", x)
+ }
+ return nil
+}
diff --git a/src/pkg/math/big/int_test.go b/src/pkg/math/big/int_test.go
index 9700a9b5a..6c981e775 100644
--- a/src/pkg/math/big/int_test.go
+++ b/src/pkg/math/big/int_test.go
@@ -8,6 +8,7 @@ import (
"bytes"
"encoding/gob"
"encoding/hex"
+ "encoding/json"
"fmt"
"math/rand"
"testing"
@@ -642,7 +643,7 @@ func TestSetBytes(t *testing.T) {
func checkBytes(b []byte) bool {
b2 := new(Int).SetBytes(b).Bytes()
- return bytes.Compare(b, b2) == 0
+ return bytes.Equal(b, b2)
}
func TestBytes(t *testing.T) {
@@ -766,8 +767,10 @@ var expTests = []struct {
x, y, m string
out string
}{
+ {"5", "-7", "", "1"},
+ {"-5", "-7", "", "1"},
{"5", "0", "", "1"},
- {"-5", "0", "", "-1"},
+ {"-5", "0", "", "1"},
{"5", "1", "", "5"},
{"-5", "1", "", "-5"},
{"-2", "3", "2", "0"},
@@ -778,6 +781,7 @@ var expTests = []struct {
{"0x8000000000000000", "3", "6719", "5447"},
{"0x8000000000000000", "1000", "6719", "1603"},
{"0x8000000000000000", "1000000", "6719", "3199"},
+ {"0x8000000000000000", "-1000000", "6719", "1"},
{
"2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347",
"298472983472983471903246121093472394872319615612417471234712061",
@@ -806,25 +810,33 @@ func TestExp(t *testing.T) {
continue
}
- z := y.Exp(x, y, m)
- if !isNormalized(z) {
- t.Errorf("#%d: %v is not normalized", i, *z)
+ z1 := new(Int).Exp(x, y, m)
+ if !isNormalized(z1) {
+ t.Errorf("#%d: %v is not normalized", i, *z1)
}
- if z.Cmp(out) != 0 {
- t.Errorf("#%d: got %s want %s", i, z, out)
+ if z1.Cmp(out) != 0 {
+ t.Errorf("#%d: got %s want %s", i, z1, out)
+ }
+
+ if m == nil {
+ // the result should be the same as for m == 0;
+ // specifically, there should be no div-zero panic
+ m = &Int{abs: nat{}} // m != nil && len(m.abs) == 0
+ z2 := new(Int).Exp(x, y, m)
+ if z2.Cmp(z1) != 0 {
+ t.Errorf("#%d: got %s want %s", i, z1, z2)
+ }
}
}
}
func checkGcd(aBytes, bBytes []byte) bool {
- a := new(Int).SetBytes(aBytes)
- b := new(Int).SetBytes(bBytes)
-
x := new(Int)
y := new(Int)
- d := new(Int)
+ a := new(Int).SetBytes(aBytes)
+ b := new(Int).SetBytes(bBytes)
- d.GCD(x, y, a, b)
+ d := new(Int).GCD(x, y, a, b)
x.Mul(x, a)
y.Mul(y, b)
x.Add(x, y)
@@ -833,32 +845,70 @@ func checkGcd(aBytes, bBytes []byte) bool {
}
var gcdTests = []struct {
- a, b int64
- d, x, y int64
+ d, x, y, a, b string
}{
- {120, 23, 1, -9, 47},
-}
-
-func TestGcd(t *testing.T) {
- for i, test := range gcdTests {
- a := NewInt(test.a)
- b := NewInt(test.b)
+ // a <= 0 || b <= 0
+ {"0", "0", "0", "0", "0"},
+ {"0", "0", "0", "0", "7"},
+ {"0", "0", "0", "11", "0"},
+ {"0", "0", "0", "-77", "35"},
+ {"0", "0", "0", "64515", "-24310"},
+ {"0", "0", "0", "-64515", "-24310"},
+
+ {"1", "-9", "47", "120", "23"},
+ {"7", "1", "-2", "77", "35"},
+ {"935", "-3", "8", "64515", "24310"},
+ {"935000000000000000", "-3", "8", "64515000000000000000", "24310000000000000000"},
+ {"1", "-221", "22059940471369027483332068679400581064239780177629666810348940098015901108344", "98920366548084643601728869055592650835572950932266967461790948584315647051443", "991"},
+
+ // test early exit (after one Euclidean iteration) in binaryGCD
+ {"1", "", "", "1", "98920366548084643601728869055592650835572950932266967461790948584315647051443"},
+}
+
+func testGcd(t *testing.T, d, x, y, a, b *Int) {
+ var X *Int
+ if x != nil {
+ X = new(Int)
+ }
+ var Y *Int
+ if y != nil {
+ Y = new(Int)
+ }
- x := new(Int)
- y := new(Int)
- d := new(Int)
+ D := new(Int).GCD(X, Y, a, b)
+ if D.Cmp(d) != 0 {
+ t.Errorf("GCD(%s, %s): got d = %s, want %s", a, b, D, d)
+ }
+ if x != nil && X.Cmp(x) != 0 {
+ t.Errorf("GCD(%s, %s): got x = %s, want %s", a, b, X, x)
+ }
+ if y != nil && Y.Cmp(y) != 0 {
+ t.Errorf("GCD(%s, %s): got y = %s, want %s", a, b, Y, y)
+ }
- expectedX := NewInt(test.x)
- expectedY := NewInt(test.y)
- expectedD := NewInt(test.d)
+ // binaryGCD requires a > 0 && b > 0
+ if a.Sign() <= 0 || b.Sign() <= 0 {
+ return
+ }
- d.GCD(x, y, a, b)
+ D.binaryGCD(a, b)
+ if D.Cmp(d) != 0 {
+ t.Errorf("binaryGcd(%s, %s): got d = %s, want %s", a, b, D, d)
+ }
+}
- if expectedX.Cmp(x) != 0 ||
- expectedY.Cmp(y) != 0 ||
- expectedD.Cmp(d) != 0 {
- t.Errorf("#%d got (%s %s %s) want (%s %s %s)", i, x, y, d, expectedX, expectedY, expectedD)
- }
+func TestGcd(t *testing.T) {
+ for _, test := range gcdTests {
+ d, _ := new(Int).SetString(test.d, 0)
+ x, _ := new(Int).SetString(test.x, 0)
+ y, _ := new(Int).SetString(test.y, 0)
+ a, _ := new(Int).SetString(test.a, 0)
+ b, _ := new(Int).SetString(test.b, 0)
+
+ testGcd(t, d, nil, nil, a, b)
+ testGcd(t, d, x, nil, a, b)
+ testGcd(t, d, nil, y, a, b)
+ testGcd(t, d, x, y, a, b)
}
quick.Check(checkGcd, nil)
@@ -1085,6 +1135,36 @@ func TestInt64(t *testing.T) {
}
}
+var uint64Tests = []uint64{
+ 0,
+ 1,
+ 4294967295,
+ 4294967296,
+ 8589934591,
+ 8589934592,
+ 9223372036854775807,
+ 9223372036854775808,
+ 18446744073709551615, // 1<<64 - 1
+}
+
+func TestUint64(t *testing.T) {
+ in := new(Int)
+ for i, testVal := range uint64Tests {
+ in.SetUint64(testVal)
+ out := in.Uint64()
+
+ if out != testVal {
+ t.Errorf("#%d got %d want %d", i, out, testVal)
+ }
+
+ str := fmt.Sprint(testVal)
+ strOut := in.String()
+ if strOut != str {
+ t.Errorf("#%d.String got %s want %s", i, strOut, str)
+ }
+ }
+}
+
var bitwiseTests = []struct {
x, y string
and, or, xor, andNot string
@@ -1368,8 +1448,12 @@ func TestModInverse(t *testing.T) {
}
}
-// used by TestIntGobEncoding and TestRatGobEncoding
-var gobEncodingTests = []string{
+var encodingTests = []string{
+ "-539345864568634858364538753846587364875430589374589",
+ "-678645873",
+ "-100",
+ "-2",
+ "-1",
"0",
"1",
"2",
@@ -1383,26 +1467,37 @@ func TestIntGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
- for i, test := range gobEncodingTests {
- for j := 0; j < 2; j++ {
- medium.Reset() // empty buffer for each test case (in case of failures)
- stest := test
- if j != 0 {
- // negative numbers
- stest = "-" + test
- }
- var tx Int
- tx.SetString(stest, 10)
- if err := enc.Encode(&tx); err != nil {
- t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
- }
- var rx Int
- if err := dec.Decode(&rx); err != nil {
- t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
- }
- if rx.Cmp(&tx) != 0 {
- t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
- }
+ for _, test := range encodingTests {
+ medium.Reset() // empty buffer for each test case (in case of failures)
+ var tx Int
+ tx.SetString(test, 10)
+ if err := enc.Encode(&tx); err != nil {
+ t.Errorf("encoding of %s failed: %s", &tx, err)
+ }
+ var rx Int
+ if err := dec.Decode(&rx); err != nil {
+ t.Errorf("decoding of %s failed: %s", &tx, err)
+ }
+ if rx.Cmp(&tx) != 0 {
+ t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx)
+ }
+ }
+}
+
+func TestIntJSONEncoding(t *testing.T) {
+ for _, test := range encodingTests {
+ var tx Int
+ tx.SetString(test, 10)
+ b, err := json.Marshal(&tx)
+ if err != nil {
+ t.Errorf("marshaling of %s failed: %s", &tx, err)
+ }
+ var rx Int
+ if err := json.Unmarshal(b, &rx); err != nil {
+ t.Errorf("unmarshaling of %s failed: %s", &tx, err)
+ }
+ if rx.Cmp(&tx) != 0 {
+ t.Errorf("JSON encoding of %s failed: got %s want %s", &tx, &rx, &tx)
}
}
}
diff --git a/src/pkg/math/big/nat.go b/src/pkg/math/big/nat.go
index eaa6ff066..9d09f97b7 100644
--- a/src/pkg/math/big/nat.go
+++ b/src/pkg/math/big/nat.go
@@ -236,7 +236,7 @@ func karatsubaSub(z, x nat, n int) {
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
-var karatsubaThreshold int = 32 // computed by calibrate.go
+var karatsubaThreshold int = 40 // computed by calibrate.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
@@ -342,7 +342,7 @@ func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
-// addAt implements z += x*(1<<(_W*i)); z must be long enough.
+// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
@@ -396,7 +396,7 @@ func (z nat) mul(x, y nat) nat {
}
// use basic multiplication if the numbers are small
- if n < karatsubaThreshold || n < 2 {
+ if n < karatsubaThreshold {
z = z.make(m + n)
basicMul(z, x, y)
return z.norm()
@@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat {
// determine Karatsuba length k such that
//
- // x = x1*b + x0
- // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
+ // x = xh*b + x0 (0 <= x0 < b)
+ // y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n)
@@ -417,27 +417,44 @@ func (z nat) mul(x, y nat) nat {
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
- z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
+ z = z[0 : m+n] // z has final length but may be incomplete
+ z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
- // If x1 and/or y1 are not 0, add missing terms to z explicitly:
+ // If xh != 0 or yh != 0, add the missing terms to z. For
+ //
+ // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
+ // yh = y1*b (0 <= y1 < b)
+ //
+ // the missing terms are
+ //
+ // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
//
- // m+n 2*k 0
- // z = [ ... | x0*y0 ]
- // + [ x1*y1 ]
- // + [ x1*y0 ]
- // + [ x0*y1 ]
+ // since all the yi for i > 1 are 0 by choice of k: If any of them
+ // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
+ // be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
- x1 := x[k:] // x1 is normalized because x is
- y1 := y[k:] // y1 is normalized because y is
var t nat
- t = t.mul(x1, y1)
- copy(z[2*k:], t)
- z[2*k+len(t):].clear() // upper portion of z is garbage
- t = t.mul(x1, y0.norm())
- addAt(z, t, k)
- t = t.mul(x0.norm(), y1)
+
+ // add x0*y1*b
+ x0 := x0.norm()
+ y1 := y[k:] // y1 is normalized because y is
+ t = t.mul(x0, y1) // update t so we don't lose t's underlying array
addAt(z, t, k)
+
+ // add xi*y0<<i, xi*y1*b<<(i+k)
+ y0 := y0.norm()
+ for i := k; i < len(x); i += k {
+ xi := x[i:]
+ if len(xi) > k {
+ xi = xi[:k]
+ }
+ xi = xi.norm()
+ t = t.mul(xi, y0)
+ addAt(z, t, i)
+ t = t.mul(xi, y1)
+ addAt(z, t, i+k)
+ }
}
return z.norm()
@@ -493,14 +510,9 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
}
if len(v) == 1 {
- var rprime Word
- q, rprime = z.divW(u, v[0])
- if rprime > 0 {
- r = z2.make(1)
- r[0] = rprime
- } else {
- r = z2.make(0)
- }
+ var r2 Word
+ q, r2 = z.divW(u, v[0])
+ r = z2.setWord(r2)
return
}
@@ -740,7 +752,7 @@ func (x nat) string(charset string) string {
// convert power of two and non power of two bases separately
if b == b&-b {
// shift is base-b digit size in bits
- shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2
+ shift := trailingZeroBits(b) // shift > 0 because b >= 2
mask := Word(1)<<shift - 1
w := x[0]
nbits := uint(_W) // number of unprocessed bits in w
@@ -814,18 +826,18 @@ func (x nat) string(charset string) string {
// Convert words of q to base b digits in s. If q is large, it is recursively "split in half"
// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using
-// repeated nat/Word divison.
+// repeated nat/Word division.
//
-// The iterative method processes n Words by n divW() calls, each of which visits every Word in the
-// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s.
-// Recursive conversion divides q by its approximate square root, yielding two parts, each half
+// The iterative method processes n Words by n divW() calls, each of which visits every Word in the
+// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s.
+// Recursive conversion divides q by its approximate square root, yielding two parts, each half
// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s
// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and
-// is made better by splitting the subblocks recursively. Best is to split blocks until one more
-// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the
-// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the
-// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
-// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
+// is made better by splitting the subblocks recursively. Best is to split blocks until one more
+// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the
+// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the
+// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
+// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
// specific hardware.
//
func (q nat) convertWords(s []byte, charset string, b Word, ndigits int, bb Word, table []divisor) {
@@ -908,8 +920,10 @@ type divisor struct {
ndigits int // digit length of divisor in terms of output base digits
}
-var cacheBase10 [64]divisor // cached divisors for base 10
-var cacheLock sync.Mutex // protects cacheBase10
+var cacheBase10 struct {
+ sync.Mutex
+ table [64]divisor // cached divisors for base 10
+}
// expWW computes x**y
func (z nat) expWW(x, y Word) nat {
@@ -925,34 +939,28 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
// determine k where (bb**leafSize)**(2**k) >= sqrt(x)
k := 1
- for words := leafSize; words < m>>1 && k < len(cacheBase10); words <<= 1 {
+ for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 {
k++
}
- // create new table of divisors or extend and reuse existing table as appropriate
- var table []divisor
- var cached bool
- switch b {
- case 10:
- table = cacheBase10[0:k] // reuse old table for this conversion
- cached = true
- default:
- table = make([]divisor, k) // new table for this conversion
+ // reuse and extend existing table of divisors or create new table as appropriate
+ var table []divisor // for b == 10, table overlaps with cacheBase10.table
+ if b == 10 {
+ cacheBase10.Lock()
+ table = cacheBase10.table[0:k] // reuse old table for this conversion
+ } else {
+ table = make([]divisor, k) // create new table for this conversion
}
// extend table
if table[k-1].ndigits == 0 {
- if cached {
- cacheLock.Lock() // begin critical section
- }
-
// add new entries as needed
var larger nat
for i := 0; i < k; i++ {
if table[i].ndigits == 0 {
if i == 0 {
- table[i].bbb = nat(nil).expWW(bb, Word(leafSize))
- table[i].ndigits = ndigits * leafSize
+ table[0].bbb = nat(nil).expWW(bb, Word(leafSize))
+ table[0].ndigits = ndigits * leafSize
} else {
table[i].bbb = nat(nil).mul(table[i-1].bbb, table[i-1].bbb)
table[i].ndigits = 2 * table[i-1].ndigits
@@ -968,10 +976,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
table[i].nbits = table[i].bbb.bitLen()
}
}
+ }
- if cached {
- cacheLock.Unlock() // end critical section
- }
+ if b == 10 {
+ cacheBase10.Unlock()
}
return table
@@ -993,10 +1001,9 @@ var deBruijn64Lookup = []byte{
54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
-// trailingZeroBits returns the number of consecutive zero bits on the right
-// side of the given Word.
-// See Knuth, volume 4, section 7.3.1
-func trailingZeroBits(x Word) int {
+// trailingZeroBits returns the number of consecutive least significant zero
+// bits of x.
+func trailingZeroBits(x Word) uint {
// x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multiplying by a power of two is equivalent to
@@ -1005,18 +1012,33 @@ func trailingZeroBits(x Word) int {
// Therefore, if we have a left shifted version of this constant we can
// find by how many bits it was shifted by looking at which six bit
// substring ended up at the top of the word.
+ // (Knuth, volume 4, section 7.3.1)
switch _W {
case 32:
- return int(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
+ return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
case 64:
- return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
+ return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
default:
- panic("Unknown word size")
+ panic("unknown word size")
}
return 0
}
+// trailingZeroBits returns the number of consecutive least significant zero
+// bits of x.
+func (x nat) trailingZeroBits() uint {
+ if len(x) == 0 {
+ return 0
+ }
+ var i uint
+ for x[i] == 0 {
+ i++
+ }
+ // x[i] != 0
+ return i*_W + trailingZeroBits(x[i])
+}
+
// z = x << s
func (z nat) shl(x nat, s uint) nat {
m := len(x)
@@ -1169,29 +1191,6 @@ func (x nat) modW(d Word) (r Word) {
return divWVW(q, 0, x, d)
}
-// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0.
-func (x nat) powersOfTwoDecompose() (q nat, k int) {
- if len(x) == 0 {
- return x, 0
- }
-
- // One of the words must be non-zero by definition,
- // so this loop will terminate with i < len(x), and
- // i is the number of 0 words.
- i := 0
- for x[i] == 0 {
- i++
- }
- n := trailingZeroBits(x[i]) // x[i] != 0
-
- q = make(nat, len(x)-i)
- shrVU(q, x[i:], uint(n))
-
- q = q.norm()
- k = i*_W + n
- return
-}
-
// random creates a random integer in [0..limit), using the space in z if
// possible. n is the bit length of limit.
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
@@ -1207,17 +1206,19 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
mask := Word((1 << bitLengthOfMSW) - 1)
for {
- for i := range z {
- switch _W {
- case 32:
+ switch _W {
+ case 32:
+ for i := range z {
z[i] = Word(rand.Uint32())
- case 64:
+ }
+ case 64:
+ for i := range z {
z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
}
+ default:
+ panic("unknown word size")
}
-
z[len(limit)-1] &= mask
-
if z.cmp(limit) < 0 {
break
}
@@ -1226,11 +1227,11 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
return z.norm()
}
-// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
-// reuses the storage of z if possible.
+// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
+// otherwise it sets z to x**y. The result is the value of z.
func (z nat) expNN(x, y, m nat) nat {
if alias(z, x) || alias(z, y) {
- // We cannot allow in place modification of x or y.
+ // We cannot allow in-place modification of x or y.
z = nil
}
@@ -1239,15 +1240,24 @@ func (z nat) expNN(x, y, m nat) nat {
z[0] = 1
return z
}
+ // y > 0
- if m != nil {
+ if len(m) != 0 {
// We likely end up being as long as the modulus.
z = z.make(len(m))
}
z = z.set(x)
- v := y[len(y)-1]
- // It's invalid for the most significant word to be zero, therefore we
- // will find a one bit.
+
+ // If the base is non-trivial and the exponent is large, we use
+ // 4-bit, windowed exponentiation. This involves precomputing 14 values
+ // (x^2...x^15) but then reduces the number of multiply-reduces by a
+ // third. Even for a 32-bit exponent, this reduces the number of
+ // operations.
+ if len(x) > 1 && len(y) > 1 && len(m) > 0 {
+ return z.expNNWindowed(x, y, m)
+ }
+
+ v := y[len(y)-1] // v > 0 because y is normalized and y > 0
shift := leadingZeros(v) + 1
v <<= shift
var q nat
@@ -1259,15 +1269,21 @@ func (z nat) expNN(x, y, m nat) nat {
// we also multiply by x, thus adding one to the power.
w := _W - int(shift)
+ // zz and r are used to avoid allocating in mul and div as
+ // otherwise the arguments would alias.
+ var zz, r nat
for j := 0; j < w; j++ {
- z = z.mul(z, z)
+ zz = zz.mul(z, z)
+ zz, z = z, zz
if v&mask != 0 {
- z = z.mul(z, x)
+ zz = zz.mul(z, x)
+ zz, z = z, zz
}
- if m != nil {
- q, z = q.div(z, z, m)
+ if len(m) != 0 {
+ zz, r = zz.div(r, z, m)
+ zz, r, q, z = q, z, zz, r
}
v <<= 1
@@ -1277,14 +1293,17 @@ func (z nat) expNN(x, y, m nat) nat {
v = y[i]
for j := 0; j < _W; j++ {
- z = z.mul(z, z)
+ zz = zz.mul(z, z)
+ zz, z = z, zz
if v&mask != 0 {
- z = z.mul(z, x)
+ zz = zz.mul(z, x)
+ zz, z = z, zz
}
- if m != nil {
- q, z = q.div(z, z, m)
+ if len(m) != 0 {
+ zz, r = zz.div(r, z, m)
+ zz, r, q, z = q, z, zz, r
}
v <<= 1
@@ -1294,6 +1313,69 @@ func (z nat) expNN(x, y, m nat) nat {
return z.norm()
}
+// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
+func (z nat) expNNWindowed(x, y, m nat) nat {
+ // zz and r are used to avoid allocating in mul and div as otherwise
+ // the arguments would alias.
+ var zz, r nat
+
+ const n = 4
+ // powers[i] contains x^i.
+ var powers [1 << n]nat
+ powers[0] = natOne
+ powers[1] = x
+ for i := 2; i < 1<<n; i += 2 {
+ p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+ *p = p.mul(*p2, *p2)
+ zz, r = zz.div(r, *p, m)
+ *p, r = r, *p
+ *p1 = p1.mul(*p, x)
+ zz, r = zz.div(r, *p1, m)
+ *p1, r = r, *p1
+ }
+
+ z = z.setWord(1)
+
+ for i := len(y) - 1; i >= 0; i-- {
+ yi := y[i]
+ for j := 0; j < _W; j += n {
+ if i != len(y)-1 || j != 0 {
+ // Unrolled loop for significant performance
+ // gain. Use go test -bench=".*" in crypto/rsa
+ // to check performance before making changes.
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ zz = zz.mul(z, z)
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+ }
+
+ zz = zz.mul(z, powers[yi>>(_W-n)])
+ zz, z = z, zz
+ zz, r = zz.div(r, z, m)
+ z, r = r, z
+
+ yi <<= n
+ }
+ }
+
+ return z.norm()
+}
+
// probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
// If it returns true, n is prime with probability 1 - 1/4^reps.
// If it returns false, n is not prime.
@@ -1343,8 +1425,9 @@ func (n nat) probablyPrime(reps int) bool {
}
nm1 := nat(nil).sub(n, natOne)
- // 1<<k * q = nm1;
- q, k := nm1.powersOfTwoDecompose()
+ // determine q, k such that nm1 = q << k
+ k := nm1.trailingZeroBits()
+ q := nat(nil).shr(nm1, k)
nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0])))
@@ -1360,7 +1443,7 @@ NextRandom:
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
- for j := 1; j < k; j++ {
+ for j := uint(1); j < k; j++ {
y = y.mul(y, y)
quotient, y = quotient.div(y, y, n)
if y.cmp(nm1) == 0 {
diff --git a/src/pkg/math/big/nat_test.go b/src/pkg/math/big/nat_test.go
index becde5d17..2dd7bf639 100644
--- a/src/pkg/math/big/nat_test.go
+++ b/src/pkg/math/big/nat_test.go
@@ -6,6 +6,7 @@ package big
import (
"io"
+ "runtime"
"strings"
"testing"
)
@@ -62,6 +63,36 @@ var prodNN = []argNN{
{nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
{nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
+ // 3^100 * 3^28 = 3^128
+ {
+ natFromString("11790184577738583171520872861412518665678211592275841109096961"),
+ natFromString("515377520732011331036461129765621272702107522001"),
+ natFromString("22876792454961"),
+ },
+ // z = 111....1 (70000 digits)
+ // x = 10^(99*700) + ... + 10^1400 + 10^700 + 1
+ // y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit)
+ {
+ natFromString(strings.Repeat("1", 70000)),
+ natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)),
+ natFromString(strings.Repeat("1", 700)),
+ },
+ // z = 111....1 (20000 digits)
+ // x = 10^10000 + 1
+ // y = 111....1 (10000 digits)
+ {
+ natFromString(strings.Repeat("1", 20000)),
+ natFromString("1" + strings.Repeat("0", 9999) + "1"),
+ natFromString(strings.Repeat("1", 10000)),
+ },
+}
+
+func natFromString(s string) nat {
+ x, _, err := nat(nil).scan(strings.NewReader(s), 0)
+ if err != nil {
+ panic(err)
+ }
+ return x
}
func TestSet(t *testing.T) {
@@ -135,26 +166,43 @@ func TestMulRangeN(t *testing.T) {
}
}
-var mulArg, mulTmp nat
-
-func init() {
- const n = 1000
- mulArg = make(nat, n)
- for i := 0; i < n; i++ {
- mulArg[i] = _M
+// allocBytes returns the number of bytes allocated by invoking f.
+func allocBytes(f func()) uint64 {
+ var stats runtime.MemStats
+ runtime.ReadMemStats(&stats)
+ t := stats.TotalAlloc
+ f()
+ runtime.ReadMemStats(&stats)
+ return stats.TotalAlloc - t
+}
+
+// TestMulUnbalanced tests that multiplying numbers of different lengths
+// does not cause deep recursion and in turn allocate too much memory.
+// Test case for issue 3807.
+func TestMulUnbalanced(t *testing.T) {
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ x := rndNat(50000)
+ y := rndNat(40)
+ allocSize := allocBytes(func() {
+ nat(nil).mul(x, y)
+ })
+ inputSize := uint64(len(x)+len(y)) * _S
+ if ratio := allocSize / uint64(inputSize); ratio > 10 {
+ t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio)
}
}
-func benchmarkMulLoad() {
- for j := 1; j <= 10; j++ {
- x := mulArg[0 : j*100]
- mulTmp.mul(x, x)
- }
+func rndNat(n int) nat {
+ return nat(rndV(n)).norm()
}
func BenchmarkMul(b *testing.B) {
+ mulx := rndNat(1e4)
+ muly := rndNat(1e4)
+ b.ResetTimer()
for i := 0; i < b.N; i++ {
- benchmarkMulLoad()
+ var z nat
+ z.mul(mulx, muly)
}
}
@@ -362,6 +410,20 @@ func TestScanPi(t *testing.T) {
}
}
+func TestScanPiParallel(t *testing.T) {
+ const n = 2
+ c := make(chan int)
+ for i := 0; i < n; i++ {
+ go func() {
+ TestScanPi(t)
+ c <- 0
+ }()
+ }
+ for i := 0; i < n; i++ {
+ <-c
+ }
+}
+
func BenchmarkScanPi(b *testing.B) {
for i := 0; i < b.N; i++ {
var x nat
@@ -369,6 +431,28 @@ func BenchmarkScanPi(b *testing.B) {
}
}
+func BenchmarkStringPiParallel(b *testing.B) {
+ var x nat
+ x, _, _ = x.scan(strings.NewReader(pi), 0)
+ if x.decimalString() != pi {
+ panic("benchmark incorrect: conversion failed")
+ }
+ n := runtime.GOMAXPROCS(0)
+ m := b.N / n // n*m <= b.N due to flooring, but the error is neglibible (n is not very large)
+ c := make(chan int, n)
+ for i := 0; i < n; i++ {
+ go func() {
+ for j := 0; j < m; j++ {
+ x.decimalString()
+ }
+ c <- 0
+ }()
+ }
+ for i := 0; i < n; i++ {
+ <-c
+ }
+}
+
func BenchmarkScan10Base2(b *testing.B) { ScanHelper(b, 2, 10, 10) }
func BenchmarkScan100Base2(b *testing.B) { ScanHelper(b, 2, 10, 100) }
func BenchmarkScan1000Base2(b *testing.B) { ScanHelper(b, 2, 10, 1000) }
@@ -463,13 +547,13 @@ func BenchmarkLeafSize13(b *testing.B) { LeafSizeHelper(b, 10, 13) }
func BenchmarkLeafSize14(b *testing.B) { LeafSizeHelper(b, 10, 14) }
func BenchmarkLeafSize15(b *testing.B) { LeafSizeHelper(b, 10, 15) }
func BenchmarkLeafSize16(b *testing.B) { LeafSizeHelper(b, 10, 16) }
-func BenchmarkLeafSize32(b *testing.B) { LeafSizeHelper(b, 10, 32) } // try some large lengths
+func BenchmarkLeafSize32(b *testing.B) { LeafSizeHelper(b, 10, 32) } // try some large lengths
func BenchmarkLeafSize64(b *testing.B) { LeafSizeHelper(b, 10, 64) }
func LeafSizeHelper(b *testing.B, base Word, size int) {
b.StopTimer()
originalLeafSize := leafSize
- resetTable(cacheBase10[:])
+ resetTable(cacheBase10.table[:])
leafSize = size
b.StartTimer()
@@ -486,7 +570,7 @@ func LeafSizeHelper(b *testing.B, base Word, size int) {
}
b.StopTimer()
- resetTable(cacheBase10[:])
+ resetTable(cacheBase10.table[:])
leafSize = originalLeafSize
b.StartTimer()
}
@@ -616,14 +700,23 @@ func TestModW(t *testing.T) {
}
func TestTrailingZeroBits(t *testing.T) {
- var x Word
- x--
- for i := 0; i < _W; i++ {
- if trailingZeroBits(x) != i {
- t.Errorf("Failed at step %d: x: %x got: %d", i, x, trailingZeroBits(x))
+ x := Word(1)
+ for i := uint(0); i <= _W; i++ {
+ n := trailingZeroBits(x)
+ if n != i%_W {
+ t.Errorf("got trailingZeroBits(%#x) = %d; want %d", x, n, i%_W)
}
x <<= 1
}
+
+ y := nat(nil).set(natOne)
+ for i := uint(0); i <= 3*_W; i++ {
+ n := y.trailingZeroBits()
+ if n != i {
+ t.Errorf("got 0x%s.trailingZeroBits() = %d; want %d", y.string(lowercaseDigits[0:16]), n, i)
+ }
+ y = y.shl(y, 1)
+ }
}
var expNNTests = []struct {
diff --git a/src/pkg/math/big/rat.go b/src/pkg/math/big/rat.go
index 7bd83fc0f..3e6473d92 100644
--- a/src/pkg/math/big/rat.go
+++ b/src/pkg/math/big/rat.go
@@ -10,14 +10,17 @@ import (
"encoding/binary"
"errors"
"fmt"
+ "math"
"strings"
)
// A Rat represents a quotient a/b of arbitrary precision.
// The zero value for a Rat represents the value 0.
type Rat struct {
- a Int
- b nat // len(b) == 0 acts like b == 1
+ // To make zero values for Rat work w/o initialization,
+ // a zero value of b (len(b) == 0) acts like b == 1.
+ // a.neg determines the sign of the Rat, b.neg is ignored.
+ a, b Int
}
// NewRat creates a new Rat with numerator a and denominator b.
@@ -25,6 +28,156 @@ func NewRat(a, b int64) *Rat {
return new(Rat).SetFrac64(a, b)
}
+// SetFloat64 sets z to exactly f and returns z.
+// If f is not finite, SetFloat returns nil.
+func (z *Rat) SetFloat64(f float64) *Rat {
+ const expMask = 1<<11 - 1
+ bits := math.Float64bits(f)
+ mantissa := bits & (1<<52 - 1)
+ exp := int((bits >> 52) & expMask)
+ switch exp {
+ case expMask: // non-finite
+ return nil
+ case 0: // denormal
+ exp -= 1022
+ default: // normal
+ mantissa |= 1 << 52
+ exp -= 1023
+ }
+
+ shift := 52 - exp
+
+ // Optimisation (?): partially pre-normalise.
+ for mantissa&1 == 0 && shift > 0 {
+ mantissa >>= 1
+ shift--
+ }
+
+ z.a.SetUint64(mantissa)
+ z.a.neg = f < 0
+ z.b.Set(intOne)
+ if shift > 0 {
+ z.b.Lsh(&z.b, uint(shift))
+ } else {
+ z.a.Lsh(&z.a, uint(-shift))
+ }
+ return z.norm()
+}
+
+// isFinite reports whether f represents a finite rational value.
+// It is equivalent to !math.IsNan(f) && !math.IsInf(f, 0).
+func isFinite(f float64) bool {
+ return math.Abs(f) <= math.MaxFloat64
+}
+
+// low64 returns the least significant 64 bits of natural number z.
+func low64(z nat) uint64 {
+ if len(z) == 0 {
+ return 0
+ }
+ if _W == 32 && len(z) > 1 {
+ return uint64(z[1])<<32 | uint64(z[0])
+ }
+ return uint64(z[0])
+}
+
+// quotToFloat returns the non-negative IEEE 754 double-precision
+// value nearest to the quotient a/b, using round-to-even in halfway
+// cases. It does not mutate its arguments.
+// Preconditions: b is non-zero; a and b have no common factors.
+func quotToFloat(a, b nat) (f float64, exact bool) {
+ // TODO(adonovan): specialize common degenerate cases: 1.0, integers.
+ alen := a.bitLen()
+ if alen == 0 {
+ return 0, true
+ }
+ blen := b.bitLen()
+ if blen == 0 {
+ panic("division by zero")
+ }
+
+ // 1. Left-shift A or B such that quotient A/B is in [1<<53, 1<<55).
+ // (54 bits if A<B when they are left-aligned, 55 bits if A>=B.)
+ // This is 2 or 3 more than the float64 mantissa field width of 52:
+ // - the optional extra bit is shifted away in step 3 below.
+ // - the high-order 1 is omitted in float64 "normal" representation;
+ // - the low-order 1 will be used during rounding then discarded.
+ exp := alen - blen
+ var a2, b2 nat
+ a2 = a2.set(a)
+ b2 = b2.set(b)
+ if shift := 54 - exp; shift > 0 {
+ a2 = a2.shl(a2, uint(shift))
+ } else if shift < 0 {
+ b2 = b2.shl(b2, uint(-shift))
+ }
+
+ // 2. Compute quotient and remainder (q, r). NB: due to the
+ // extra shift, the low-order bit of q is logically the
+ // high-order bit of r.
+ var q nat
+ q, r := q.div(a2, a2, b2) // (recycle a2)
+ mantissa := low64(q)
+ haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
+
+ // 3. If quotient didn't fit in 54 bits, re-do division by b2<<1
+ // (in effect---we accomplish this incrementally).
+ if mantissa>>54 == 1 {
+ if mantissa&1 == 1 {
+ haveRem = true
+ }
+ mantissa >>= 1
+ exp++
+ }
+ if mantissa>>53 != 1 {
+ panic("expected exactly 54 bits of result")
+ }
+
+ // 4. Rounding.
+ if -1022-52 <= exp && exp <= -1022 {
+ // Denormal case; lose 'shift' bits of precision.
+ shift := uint64(-1022 - (exp - 1)) // [1..53)
+ lostbits := mantissa & (1<<shift - 1)
+ haveRem = haveRem || lostbits != 0
+ mantissa >>= shift
+ exp = -1023 + 2
+ }
+ // Round q using round-half-to-even.
+ exact = !haveRem
+ if mantissa&1 != 0 {
+ exact = false
+ if haveRem || mantissa&2 != 0 {
+ if mantissa++; mantissa >= 1<<54 {
+ // Complete rollover 11...1 => 100...0, so shift is safe
+ mantissa >>= 1
+ exp++
+ }
+ }
+ }
+ mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 2^53.
+
+ f = math.Ldexp(float64(mantissa), exp-53)
+ if math.IsInf(f, 0) {
+ exact = false
+ }
+ return
+}
+
+// Float64 returns the nearest float64 value to z.
+// If z is exactly representable as a float64, Float64 returns exact=true.
+// If z is negative, so too is f, even if f==0.
+func (z *Rat) Float64() (f float64, exact bool) {
+ b := z.b.abs
+ if len(b) == 0 {
+ b = b.set(natOne) // materialize denominator
+ }
+ f, exact = quotToFloat(z.a.abs, b)
+ if z.a.neg {
+ f = -f
+ }
+ return
+}
+
// SetFrac sets z to a/b and returns z.
func (z *Rat) SetFrac(a, b *Int) *Rat {
z.a.neg = a.neg != b.neg
@@ -36,7 +189,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
babs = nat(nil).set(babs) // make a copy
}
z.a.abs = z.a.abs.set(a.abs)
- z.b = z.b.set(babs)
+ z.b.abs = z.b.abs.set(babs)
return z.norm()
}
@@ -50,21 +203,21 @@ func (z *Rat) SetFrac64(a, b int64) *Rat {
b = -b
z.a.neg = !z.a.neg
}
- z.b = z.b.setUint64(uint64(b))
+ z.b.abs = z.b.abs.setUint64(uint64(b))
return z.norm()
}
// SetInt sets z to x (by making a copy of x) and returns z.
func (z *Rat) SetInt(x *Int) *Rat {
z.a.Set(x)
- z.b = z.b.make(0)
+ z.b.abs = z.b.abs.make(0)
return z
}
// SetInt64 sets z to x and returns z.
func (z *Rat) SetInt64(x int64) *Rat {
z.a.SetInt64(x)
- z.b = z.b.make(0)
+ z.b.abs = z.b.abs.make(0)
return z
}
@@ -72,7 +225,7 @@ func (z *Rat) SetInt64(x int64) *Rat {
func (z *Rat) Set(x *Rat) *Rat {
if z != x {
z.a.Set(&x.a)
- z.b = z.b.set(x.b)
+ z.b.Set(&x.b)
}
return z
}
@@ -97,15 +250,15 @@ func (z *Rat) Inv(x *Rat) *Rat {
panic("division by zero")
}
z.Set(x)
- a := z.b
+ a := z.b.abs
if len(a) == 0 {
- a = a.setWord(1) // materialize numerator
+ a = a.set(natOne) // materialize numerator
}
b := z.a.abs
if b.cmp(natOne) == 0 {
b = b.make(0) // normalize denominator
}
- z.a.abs, z.b = a, b // sign doesn't change
+ z.a.abs, z.b.abs = a, b // sign doesn't change
return z
}
@@ -121,38 +274,26 @@ func (x *Rat) Sign() int {
// IsInt returns true if the denominator of x is 1.
func (x *Rat) IsInt() bool {
- return len(x.b) == 0 || x.b.cmp(natOne) == 0
+ return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0
}
// Num returns the numerator of x; it may be <= 0.
// The result is a reference to x's numerator; it
-// may change if a new value is assigned to x.
+// may change if a new value is assigned to x, and vice versa.
+// The sign of the numerator corresponds to the sign of x.
func (x *Rat) Num() *Int {
return &x.a
}
// Denom returns the denominator of x; it is always > 0.
// The result is a reference to x's denominator; it
-// may change if a new value is assigned to x.
+// may change if a new value is assigned to x, and vice versa.
func (x *Rat) Denom() *Int {
- if len(x.b) == 0 {
- return &Int{abs: nat{1}}
+ x.b.neg = false // the result is always >= 0
+ if len(x.b.abs) == 0 {
+ x.b.abs = x.b.abs.set(natOne) // materialize denominator
}
- return &Int{abs: x.b}
-}
-
-func gcd(x, y nat) nat {
- // Euclidean algorithm.
- var a, b nat
- a = a.set(x)
- b = b.set(y)
- for len(b) != 0 {
- var q, r nat
- _, r = q.div(r, a, b)
- a = b
- b = r
- }
- return a
+ return &x.b
}
func (z *Rat) norm() *Rat {
@@ -160,17 +301,25 @@ func (z *Rat) norm() *Rat {
case len(z.a.abs) == 0:
// z == 0 - normalize sign and denominator
z.a.neg = false
- z.b = z.b.make(0)
- case len(z.b) == 0:
+ z.b.abs = z.b.abs.make(0)
+ case len(z.b.abs) == 0:
// z is normalized int - nothing to do
- case z.b.cmp(natOne) == 0:
+ case z.b.abs.cmp(natOne) == 0:
// z is int - normalize denominator
- z.b = z.b.make(0)
+ z.b.abs = z.b.abs.make(0)
default:
- if f := gcd(z.a.abs, z.b); f.cmp(natOne) != 0 {
- z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f)
- z.b, _ = z.b.div(nil, z.b, f)
+ neg := z.a.neg
+ z.a.neg = false
+ z.b.neg = false
+ if f := NewInt(0).binaryGCD(&z.a, &z.b); f.Cmp(intOne) != 0 {
+ z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
+ z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
+ if z.b.abs.cmp(natOne) == 0 {
+ // z is int - normalize denominator
+ z.b.abs = z.b.abs.make(0)
+ }
}
+ z.a.neg = neg
}
return z
}
@@ -207,31 +356,31 @@ func scaleDenom(x *Int, f nat) *Int {
// +1 if x > y
//
func (x *Rat) Cmp(y *Rat) int {
- return scaleDenom(&x.a, y.b).Cmp(scaleDenom(&y.a, x.b))
+ return scaleDenom(&x.a, y.b.abs).Cmp(scaleDenom(&y.a, x.b.abs))
}
// Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat {
- a1 := scaleDenom(&x.a, y.b)
- a2 := scaleDenom(&y.a, x.b)
+ a1 := scaleDenom(&x.a, y.b.abs)
+ a2 := scaleDenom(&y.a, x.b.abs)
z.a.Add(a1, a2)
- z.b = mulDenom(z.b, x.b, y.b)
+ z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat {
- a1 := scaleDenom(&x.a, y.b)
- a2 := scaleDenom(&y.a, x.b)
+ a1 := scaleDenom(&x.a, y.b.abs)
+ a2 := scaleDenom(&y.a, x.b.abs)
z.a.Sub(a1, a2)
- z.b = mulDenom(z.b, x.b, y.b)
+ z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat {
z.a.Mul(&x.a, &y.a)
- z.b = mulDenom(z.b, x.b, y.b)
+ z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
@@ -241,10 +390,10 @@ func (z *Rat) Quo(x, y *Rat) *Rat {
if len(y.a.abs) == 0 {
panic("division by zero")
}
- a := scaleDenom(&x.a, y.b)
- b := scaleDenom(&y.a, x.b)
+ a := scaleDenom(&x.a, y.b.abs)
+ b := scaleDenom(&y.a, x.b.abs)
z.a.abs = a.abs
- z.b = b.abs
+ z.b.abs = b.abs
z.a.neg = a.neg != b.neg
return z.norm()
}
@@ -286,7 +435,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
}
s = s[sep+1:]
var err error
- if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil {
+ if z.b.abs, _, err = z.b.abs.scan(strings.NewReader(s), 10); err != nil {
return nil, false
}
return z.norm(), true
@@ -317,11 +466,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
}
powTen := nat(nil).expNN(natTen, exp.abs, nil)
if exp.neg {
- z.b = powTen
+ z.b.abs = powTen
z.norm()
} else {
z.a.abs = z.a.abs.mul(z.a.abs, powTen)
- z.b = z.b.make(0)
+ z.b.abs = z.b.abs.make(0)
}
return z, true
@@ -330,8 +479,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
// String returns a string representation of z in the form "a/b" (even if b == 1).
func (x *Rat) String() string {
s := "/1"
- if len(x.b) != 0 {
- s = "/" + x.b.decimalString()
+ if len(x.b.abs) != 0 {
+ s = "/" + x.b.abs.decimalString()
}
return x.a.String() + s
}
@@ -355,9 +504,9 @@ func (x *Rat) FloatString(prec int) string {
}
return s
}
- // x.b != 0
+ // x.b.abs != 0
- q, r := nat(nil).div(nat(nil), x.a.abs, x.b)
+ q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs)
p := natOne
if prec > 0 {
@@ -365,11 +514,11 @@ func (x *Rat) FloatString(prec int) string {
}
r = r.mul(r, p)
- r, r2 := r.div(nat(nil), r, x.b)
+ r, r2 := r.div(nat(nil), r, x.b.abs)
// see if we need to round up
r2 = r2.add(r2, r2)
- if x.b.cmp(r2) <= 0 {
+ if x.b.abs.cmp(r2) <= 0 {
r = r.add(r, natOne)
if r.cmp(p) >= 0 {
q = nat(nil).add(q, natOne)
@@ -396,8 +545,8 @@ const ratGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (x *Rat) GobEncode() ([]byte, error) {
- buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
- i := x.b.bytes(buf)
+ buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
+ i := x.b.abs.bytes(buf)
j := x.a.abs.bytes(buf[0:i])
n := i - j
if int(uint32(n)) != n {
@@ -427,6 +576,6 @@ func (z *Rat) GobDecode(buf []byte) error {
i := j + binary.BigEndian.Uint32(buf[j-4:j])
z.a.neg = b&1 != 0
z.a.abs = z.a.abs.setBytes(buf[j:i])
- z.b = z.b.setBytes(buf[i:])
+ z.b.abs = z.b.abs.setBytes(buf[i:])
return nil
}
diff --git a/src/pkg/math/big/rat_test.go b/src/pkg/math/big/rat_test.go
index f7f31ae1a..462dfb723 100644
--- a/src/pkg/math/big/rat_test.go
+++ b/src/pkg/math/big/rat_test.go
@@ -8,6 +8,9 @@ import (
"bytes"
"encoding/gob"
"fmt"
+ "math"
+ "strconv"
+ "strings"
"testing"
)
@@ -387,30 +390,19 @@ func TestRatGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
- for i, test := range gobEncodingTests {
- for j := 0; j < 4; j++ {
- medium.Reset() // empty buffer for each test case (in case of failures)
- stest := test
- if j&1 != 0 {
- // negative numbers
- stest = "-" + test
- }
- if j%2 != 0 {
- // fractions
- stest = stest + "." + test
- }
- var tx Rat
- tx.SetString(stest)
- if err := enc.Encode(&tx); err != nil {
- t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
- }
- var rx Rat
- if err := dec.Decode(&rx); err != nil {
- t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
- }
- if rx.Cmp(&tx) != 0 {
- t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
- }
+ for _, test := range encodingTests {
+ medium.Reset() // empty buffer for each test case (in case of failures)
+ var tx Rat
+ tx.SetString(test + ".14159265")
+ if err := enc.Encode(&tx); err != nil {
+ t.Errorf("encoding of %s failed: %s", &tx, err)
+ }
+ var rx Rat
+ if err := dec.Decode(&rx); err != nil {
+ t.Errorf("decoding of %s failed: %s", &tx, err)
+ }
+ if rx.Cmp(&tx) != 0 {
+ t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx)
}
}
}
@@ -454,3 +446,462 @@ func TestIssue2379(t *testing.T) {
t.Errorf("5) got %s want %s", x, q)
}
}
+
+func TestIssue3521(t *testing.T) {
+ a := new(Int)
+ b := new(Int)
+ a.SetString("64375784358435883458348587", 0)
+ b.SetString("4789759874531", 0)
+
+ // 0) a raw zero value has 1 as denominator
+ zero := new(Rat)
+ one := NewInt(1)
+ if zero.Denom().Cmp(one) != 0 {
+ t.Errorf("0) got %s want %s", zero.Denom(), one)
+ }
+
+ // 1a) a zero value remains zero independent of denominator
+ x := new(Rat)
+ x.Denom().Set(new(Int).Neg(b))
+ if x.Cmp(zero) != 0 {
+ t.Errorf("1a) got %s want %s", x, zero)
+ }
+
+ // 1b) a zero value may have a denominator != 0 and != 1
+ x.Num().Set(a)
+ qab := new(Rat).SetFrac(a, b)
+ if x.Cmp(qab) != 0 {
+ t.Errorf("1b) got %s want %s", x, qab)
+ }
+
+ // 2a) an integral value becomes a fraction depending on denominator
+ x.SetFrac64(10, 2)
+ x.Denom().SetInt64(3)
+ q53 := NewRat(5, 3)
+ if x.Cmp(q53) != 0 {
+ t.Errorf("2a) got %s want %s", x, q53)
+ }
+
+ // 2b) an integral value becomes a fraction depending on denominator
+ x = NewRat(10, 2)
+ x.Denom().SetInt64(3)
+ if x.Cmp(q53) != 0 {
+ t.Errorf("2b) got %s want %s", x, q53)
+ }
+
+ // 3) changing the numerator/denominator of a Rat changes the Rat
+ x.SetFrac(a, b)
+ a = x.Num()
+ b = x.Denom()
+ a.SetInt64(5)
+ b.SetInt64(3)
+ if x.Cmp(q53) != 0 {
+ t.Errorf("3) got %s want %s", x, q53)
+ }
+}
+
+// Test inputs to Rat.SetString. The prefix "long:" causes the test
+// to be skipped in --test.short mode. (The threshold is about 500us.)
+var float64inputs = []string{
+ //
+ // Constants plundered from strconv/testfp.txt.
+ //
+
+ // Table 1: Stress Inputs for Conversion to 53-bit Binary, < 1/2 ULP
+ "5e+125",
+ "69e+267",
+ "999e-026",
+ "7861e-034",
+ "75569e-254",
+ "928609e-261",
+ "9210917e+080",
+ "84863171e+114",
+ "653777767e+273",
+ "5232604057e-298",
+ "27235667517e-109",
+ "653532977297e-123",
+ "3142213164987e-294",
+ "46202199371337e-072",
+ "231010996856685e-073",
+ "9324754620109615e+212",
+ "78459735791271921e+049",
+ "272104041512242479e+200",
+ "6802601037806061975e+198",
+ "20505426358836677347e-221",
+ "836168422905420598437e-234",
+ "4891559871276714924261e+222",
+
+ // Table 2: Stress Inputs for Conversion to 53-bit Binary, > 1/2 ULP
+ "9e-265",
+ "85e-037",
+ "623e+100",
+ "3571e+263",
+ "81661e+153",
+ "920657e-023",
+ "4603285e-024",
+ "87575437e-309",
+ "245540327e+122",
+ "6138508175e+120",
+ "83356057653e+193",
+ "619534293513e+124",
+ "2335141086879e+218",
+ "36167929443327e-159",
+ "609610927149051e-255",
+ "3743626360493413e-165",
+ "94080055902682397e-242",
+ "899810892172646163e+283",
+ "7120190517612959703e+120",
+ "25188282901709339043e-252",
+ "308984926168550152811e-052",
+ "6372891218502368041059e+064",
+
+ // Table 14: Stress Inputs for Conversion to 24-bit Binary, <1/2 ULP
+ "5e-20",
+ "67e+14",
+ "985e+15",
+ "7693e-42",
+ "55895e-16",
+ "996622e-44",
+ "7038531e-32",
+ "60419369e-46",
+ "702990899e-20",
+ "6930161142e-48",
+ "25933168707e+13",
+ "596428896559e+20",
+
+ // Table 15: Stress Inputs for Conversion to 24-bit Binary, >1/2 ULP
+ "3e-23",
+ "57e+18",
+ "789e-35",
+ "2539e-18",
+ "76173e+28",
+ "887745e-11",
+ "5382571e-37",
+ "82381273e-35",
+ "750486563e-38",
+ "3752432815e-39",
+ "75224575729e-45",
+ "459926601011e+15",
+
+ //
+ // Constants plundered from strconv/atof_test.go.
+ //
+
+ "0",
+ "1",
+ "+1",
+ "1e23",
+ "1E23",
+ "100000000000000000000000",
+ "1e-100",
+ "123456700",
+ "99999999999999974834176",
+ "100000000000000000000001",
+ "100000000000000008388608",
+ "100000000000000016777215",
+ "100000000000000016777216",
+ "-1",
+ "-0.1",
+ "-0", // NB: exception made for this input
+ "1e-20",
+ "625e-3",
+
+ // largest float64
+ "1.7976931348623157e308",
+ "-1.7976931348623157e308",
+ // next float64 - too large
+ "1.7976931348623159e308",
+ "-1.7976931348623159e308",
+ // the border is ...158079
+ // borderline - okay
+ "1.7976931348623158e308",
+ "-1.7976931348623158e308",
+ // borderline - too large
+ "1.797693134862315808e308",
+ "-1.797693134862315808e308",
+
+ // a little too large
+ "1e308",
+ "2e308",
+ "1e309",
+
+ // way too large
+ "1e310",
+ "-1e310",
+ "1e400",
+ "-1e400",
+ "long:1e400000",
+ "long:-1e400000",
+
+ // denormalized
+ "1e-305",
+ "1e-306",
+ "1e-307",
+ "1e-308",
+ "1e-309",
+ "1e-310",
+ "1e-322",
+ // smallest denormal
+ "5e-324",
+ "4e-324",
+ "3e-324",
+ // too small
+ "2e-324",
+ // way too small
+ "1e-350",
+ "long:1e-400000",
+ // way too small, negative
+ "-1e-350",
+ "long:-1e-400000",
+
+ // try to overflow exponent
+ // [Disabled: too slow and memory-hungry with rationals.]
+ // "1e-4294967296",
+ // "1e+4294967296",
+ // "1e-18446744073709551616",
+ // "1e+18446744073709551616",
+
+ // http://www.exploringbinary.com/java-hangs-when-converting-2-2250738585072012e-308/
+ "2.2250738585072012e-308",
+ // http://www.exploringbinary.com/php-hangs-on-numeric-value-2-2250738585072011e-308/
+
+ "2.2250738585072011e-308",
+
+ // A very large number (initially wrongly parsed by the fast algorithm).
+ "4.630813248087435e+307",
+
+ // A different kind of very large number.
+ "22.222222222222222",
+ "long:2." + strings.Repeat("2", 4000) + "e+1",
+
+ // Exactly halfway between 1 and math.Nextafter(1, 2).
+ // Round to even (down).
+ "1.00000000000000011102230246251565404236316680908203125",
+ // Slightly lower; still round down.
+ "1.00000000000000011102230246251565404236316680908203124",
+ // Slightly higher; round up.
+ "1.00000000000000011102230246251565404236316680908203126",
+ // Slightly higher, but you have to read all the way to the end.
+ "long:1.00000000000000011102230246251565404236316680908203125" + strings.Repeat("0", 10000) + "1",
+
+ // Smallest denormal, 2^(-1022-52)
+ "4.940656458412465441765687928682213723651e-324",
+ // Half of smallest denormal, 2^(-1022-53)
+ "2.470328229206232720882843964341106861825e-324",
+ // A little more than the exact half of smallest denormal
+ // 2^-1075 + 2^-1100. (Rounds to 1p-1074.)
+ "2.470328302827751011111470718709768633275e-324",
+ // The exact halfway between smallest normal and largest denormal:
+ // 2^-1022 - 2^-1075. (Rounds to 2^-1022.)
+ "2.225073858507201136057409796709131975935e-308",
+
+ "1152921504606846975", // 1<<60 - 1
+ "-1152921504606846975", // -(1<<60 - 1)
+ "1152921504606846977", // 1<<60 + 1
+ "-1152921504606846977", // -(1<<60 + 1)
+
+ "1/3",
+}
+
+func TestFloat64SpecialCases(t *testing.T) {
+ for _, input := range float64inputs {
+ if strings.HasPrefix(input, "long:") {
+ if testing.Short() {
+ continue
+ }
+ input = input[len("long:"):]
+ }
+
+ r, ok := new(Rat).SetString(input)
+ if !ok {
+ t.Errorf("Rat.SetString(%q) failed", input)
+ continue
+ }
+ f, exact := r.Float64()
+
+ // 1. Check string -> Rat -> float64 conversions are
+ // consistent with strconv.ParseFloat.
+ // Skip this check if the input uses "a/b" rational syntax.
+ if !strings.Contains(input, "/") {
+ e, _ := strconv.ParseFloat(input, 64)
+
+ // Careful: negative Rats too small for
+ // float64 become -0, but Rat obviously cannot
+ // preserve the sign from SetString("-0").
+ switch {
+ case math.Float64bits(e) == math.Float64bits(f):
+ // Ok: bitwise equal.
+ case f == 0 && r.Num().BitLen() == 0:
+ // Ok: Rat(0) is equivalent to both +/- float64(0).
+ default:
+ t.Errorf("strconv.ParseFloat(%q) = %g (%b), want %g (%b); delta=%g", input, e, e, f, f, f-e)
+ }
+ }
+
+ if !isFinite(f) {
+ continue
+ }
+
+ // 2. Check f is best approximation to r.
+ if !checkIsBestApprox(t, f, r) {
+ // Append context information.
+ t.Errorf("(input was %q)", input)
+ }
+
+ // 3. Check f->R->f roundtrip is non-lossy.
+ checkNonLossyRoundtrip(t, f)
+
+ // 4. Check exactness using slow algorithm.
+ if wasExact := new(Rat).SetFloat64(f).Cmp(r) == 0; wasExact != exact {
+ t.Errorf("Rat.SetString(%q).Float64().exact = %t, want %t", input, exact, wasExact)
+ }
+ }
+}
+
+func TestFloat64Distribution(t *testing.T) {
+ // Generate a distribution of (sign, mantissa, exp) values
+ // broader than the float64 range, and check Rat.Float64()
+ // always picks the closest float64 approximation.
+ var add = []int64{
+ 0,
+ 1,
+ 3,
+ 5,
+ 7,
+ 9,
+ 11,
+ }
+ var winc, einc = uint64(1), int(1) // soak test (~75s on x86-64)
+ if testing.Short() {
+ winc, einc = 10, 500 // quick test (~12ms on x86-64)
+ }
+
+ for _, sign := range "+-" {
+ for _, a := range add {
+ for wid := uint64(0); wid < 60; wid += winc {
+ b := int64(1<<wid + a)
+ if sign == '-' {
+ b = -b
+ }
+ for exp := -1100; exp < 1100; exp += einc {
+ num, den := NewInt(b), NewInt(1)
+ if exp > 0 {
+ num.Lsh(num, uint(exp))
+ } else {
+ den.Lsh(den, uint(-exp))
+ }
+ r := new(Rat).SetFrac(num, den)
+ f, _ := r.Float64()
+
+ if !checkIsBestApprox(t, f, r) {
+ // Append context information.
+ t.Errorf("(input was mantissa %#x, exp %d; f=%g (%b); f~%g; r=%v)",
+ b, exp, f, f, math.Ldexp(float64(b), exp), r)
+ }
+
+ checkNonLossyRoundtrip(t, f)
+ }
+ }
+ }
+ }
+}
+
+// TestFloat64NonFinite checks that SetFloat64 of a non-finite value
+// returns nil.
+func TestSetFloat64NonFinite(t *testing.T) {
+ for _, f := range []float64{math.NaN(), math.Inf(+1), math.Inf(-1)} {
+ var r Rat
+ if r2 := r.SetFloat64(f); r2 != nil {
+ t.Errorf("SetFloat64(%g) was %v, want nil", f, r2)
+ }
+ }
+}
+
+// checkNonLossyRoundtrip checks that a float->Rat->float roundtrip is
+// non-lossy for finite f.
+func checkNonLossyRoundtrip(t *testing.T, f float64) {
+ if !isFinite(f) {
+ return
+ }
+ r := new(Rat).SetFloat64(f)
+ if r == nil {
+ t.Errorf("Rat.SetFloat64(%g (%b)) == nil", f, f)
+ return
+ }
+ f2, exact := r.Float64()
+ if f != f2 || !exact {
+ t.Errorf("Rat.SetFloat64(%g).Float64() = %g (%b), %v, want %g (%b), %v; delta=%b",
+ f, f2, f2, exact, f, f, true, f2-f)
+ }
+}
+
+// delta returns the absolute difference between r and f.
+func delta(r *Rat, f float64) *Rat {
+ d := new(Rat).Sub(r, new(Rat).SetFloat64(f))
+ return d.Abs(d)
+}
+
+// checkIsBestApprox checks that f is the best possible float64
+// approximation of r.
+// Returns true on success.
+func checkIsBestApprox(t *testing.T, f float64, r *Rat) bool {
+ if math.Abs(f) >= math.MaxFloat64 {
+ // Cannot check +Inf, -Inf, nor the float next to them (MaxFloat64).
+ // But we have tests for these special cases.
+ return true
+ }
+
+ // r must be strictly between f0 and f1, the floats bracketing f.
+ f0 := math.Nextafter(f, math.Inf(-1))
+ f1 := math.Nextafter(f, math.Inf(+1))
+
+ // For f to be correct, r must be closer to f than to f0 or f1.
+ df := delta(r, f)
+ df0 := delta(r, f0)
+ df1 := delta(r, f1)
+ if df.Cmp(df0) > 0 {
+ t.Errorf("Rat(%v).Float64() = %g (%b), but previous float64 %g (%b) is closer", r, f, f, f0, f0)
+ return false
+ }
+ if df.Cmp(df1) > 0 {
+ t.Errorf("Rat(%v).Float64() = %g (%b), but next float64 %g (%b) is closer", r, f, f, f1, f1)
+ return false
+ }
+ if df.Cmp(df0) == 0 && !isEven(f) {
+ t.Errorf("Rat(%v).Float64() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f0, f0)
+ return false
+ }
+ if df.Cmp(df1) == 0 && !isEven(f) {
+ t.Errorf("Rat(%v).Float64() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f1, f1)
+ return false
+ }
+ return true
+}
+
+func isEven(f float64) bool { return math.Float64bits(f)&1 == 0 }
+
+func TestIsFinite(t *testing.T) {
+ finites := []float64{
+ 1.0 / 3,
+ 4891559871276714924261e+222,
+ math.MaxFloat64,
+ math.SmallestNonzeroFloat64,
+ -math.MaxFloat64,
+ -math.SmallestNonzeroFloat64,
+ }
+ for _, f := range finites {
+ if !isFinite(f) {
+ t.Errorf("!IsFinite(%g (%b))", f, f)
+ }
+ }
+ nonfinites := []float64{
+ math.NaN(),
+ math.Inf(-1),
+ math.Inf(+1),
+ }
+ for _, f := range nonfinites {
+ if isFinite(f) {
+ t.Errorf("IsFinite(%g, (%b))", f, f)
+ }
+ }
+}