diff options
Diffstat (limited to 'src/pkg/math/big')
-rw-r--r-- | src/pkg/math/big/arith_386.s | 18 | ||||
-rw-r--r-- | src/pkg/math/big/arith_amd64.s | 331 | ||||
-rw-r--r-- | src/pkg/math/big/arith_arm.s | 18 | ||||
-rw-r--r-- | src/pkg/math/big/arith_test.go | 86 | ||||
-rw-r--r-- | src/pkg/math/big/calibrate_test.go | 42 | ||||
-rw-r--r-- | src/pkg/math/big/gcd_test.go | 47 | ||||
-rw-r--r-- | src/pkg/math/big/int.go | 124 | ||||
-rw-r--r-- | src/pkg/math/big/int_test.go | 205 | ||||
-rw-r--r-- | src/pkg/math/big/nat.go | 309 | ||||
-rw-r--r-- | src/pkg/math/big/nat_test.go | 135 | ||||
-rw-r--r-- | src/pkg/math/big/rat.go | 269 | ||||
-rw-r--r-- | src/pkg/math/big/rat_test.go | 499 |
12 files changed, 1656 insertions, 427 deletions
diff --git a/src/pkg/math/big/arith_386.s b/src/pkg/math/big/arith_386.s index f1262c651..c62483317 100644 --- a/src/pkg/math/big/arith_386.s +++ b/src/pkg/math/big/arith_386.s @@ -29,7 +29,7 @@ TEXT ·addVV(SB),7,$0 MOVL z+0(FP), DI MOVL x+12(FP), SI MOVL y+24(FP), CX - MOVL n+4(FP), BP + MOVL z+4(FP), BP MOVL $0, BX // i = 0 MOVL $0, DX // c = 0 JMP E1 @@ -54,7 +54,7 @@ TEXT ·subVV(SB),7,$0 MOVL z+0(FP), DI MOVL x+12(FP), SI MOVL y+24(FP), CX - MOVL n+4(FP), BP + MOVL z+4(FP), BP MOVL $0, BX // i = 0 MOVL $0, DX // c = 0 JMP E2 @@ -78,7 +78,7 @@ TEXT ·addVW(SB),7,$0 MOVL z+0(FP), DI MOVL x+12(FP), SI MOVL y+24(FP), AX // c = y - MOVL n+4(FP), BP + MOVL z+4(FP), BP MOVL $0, BX // i = 0 JMP E3 @@ -100,7 +100,7 @@ TEXT ·subVW(SB),7,$0 MOVL z+0(FP), DI MOVL x+12(FP), SI MOVL y+24(FP), AX // c = y - MOVL n+4(FP), BP + MOVL z+4(FP), BP MOVL $0, BX // i = 0 JMP E4 @@ -120,7 +120,7 @@ E4: CMPL BX, BP // i < n // func shlVU(z, x []Word, s uint) (c Word) TEXT ·shlVU(SB),7,$0 - MOVL n+4(FP), BX // i = n + MOVL z+4(FP), BX // i = z SUBL $1, BX // i-- JL X8b // i < 0 (n <= 0) @@ -155,7 +155,7 @@ X8b: MOVL $0, c+28(FP) // func shrVU(z, x []Word, s uint) (c Word) TEXT ·shrVU(SB),7,$0 - MOVL n+4(FP), BP + MOVL z+4(FP), BP SUBL $1, BP // n-- JL X9b // n < 0 (n <= 0) @@ -196,7 +196,7 @@ TEXT ·mulAddVWW(SB),7,$0 MOVL x+12(FP), SI MOVL y+24(FP), BP MOVL r+28(FP), CX // c = r - MOVL n+4(FP), BX + MOVL z+4(FP), BX LEAL (DI)(BX*4), DI LEAL (SI)(BX*4), SI NEGL BX // i = -n @@ -222,7 +222,7 @@ TEXT ·addMulVVW(SB),7,$0 MOVL z+0(FP), DI MOVL x+12(FP), SI MOVL y+24(FP), BP - MOVL n+4(FP), BX + MOVL z+4(FP), BX LEAL (DI)(BX*4), DI LEAL (SI)(BX*4), SI NEGL BX // i = -n @@ -251,7 +251,7 @@ TEXT ·divWVW(SB),7,$0 MOVL xn+12(FP), DX // r = xn MOVL x+16(FP), SI MOVL y+28(FP), CX - MOVL n+4(FP), BX // i = n + MOVL z+4(FP), BX // i = z JMP E7 L7: MOVL (SI)(BX*4), AX diff --git a/src/pkg/math/big/arith_amd64.s b/src/pkg/math/big/arith_amd64.s index 54f647322..d85964502 100644 --- a/src/pkg/math/big/arith_amd64.s +++ b/src/pkg/math/big/arith_amd64.s @@ -5,7 +5,15 @@ // This file provides fast assembly versions for the elementary // arithmetic operations on vectors implemented in arith.go. -// TODO(gri) - experiment with unrolled loops for faster execution +// Literal instruction for MOVQ $0, CX. +// (MOVQ $0, reg is translated to XORQ reg, reg and clears CF.) +#define ZERO_CX BYTE $0x48; \ + BYTE $0xc7; \ + BYTE $0xc1; \ + BYTE $0x00; \ + BYTE $0x00; \ + BYTE $0x00; \ + BYTE $0x00 // func mulWW(x, y Word) (z1, z0 Word) TEXT ·mulWW(SB),7,$0 @@ -28,114 +36,231 @@ TEXT ·divWW(SB),7,$0 // func addVV(z, x, y []Word) (c Word) TEXT ·addVV(SB),7,$0 + MOVQ z+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVQ y+32(FP), R9 - MOVL n+8(FP), R11 - MOVQ $0, BX // i = 0 - MOVQ $0, DX // c = 0 - JMP E1 - -L1: MOVQ (R8)(BX*8), AX - RCRQ $1, DX - ADCQ (R9)(BX*8), AX - RCLQ $1, DX - MOVQ AX, (R10)(BX*8) - ADDL $1, BX // i++ -E1: CMPQ BX, R11 // i < n - JL L1 - - MOVQ DX, c+48(FP) + MOVQ $0, CX // c = 0 + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V1 // if n < 0 goto V1 + +U1: // n >= 0 + // regular loop body unrolled 4x + RCRQ $1, CX // CF = c + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + ADCQ 0(R9)(SI*8), R11 + ADCQ 8(R9)(SI*8), R12 + ADCQ 16(R9)(SI*8), R13 + ADCQ 24(R9)(SI*8), R14 + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + RCLQ $1, CX // c = CF + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U1 // if n >= 0 goto U1 + +V1: ADDQ $4, DI // n += 4 + JLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + RCRQ $1, CX // CF = c + MOVQ 0(R8)(SI*8), R11 + ADCQ 0(R9)(SI*8), R11 + MOVQ R11, 0(R10)(SI*8) + RCLQ $1, CX // c = CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L1 // if n > 0 goto L1 + +E1: MOVQ CX, c+72(FP) // return c RET // func subVV(z, x, y []Word) (c Word) -// (same as addVV_s except for SBBQ instead of ADCQ and label names) +// (same as addVV except for SBBQ instead of ADCQ and label names) TEXT ·subVV(SB),7,$0 + MOVQ z+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVQ y+32(FP), R9 - MOVL n+8(FP), R11 - MOVQ $0, BX // i = 0 - MOVQ $0, DX // c = 0 - JMP E2 - -L2: MOVQ (R8)(BX*8), AX - RCRQ $1, DX - SBBQ (R9)(BX*8), AX - RCLQ $1, DX - MOVQ AX, (R10)(BX*8) - ADDL $1, BX // i++ - -E2: CMPQ BX, R11 // i < n - JL L2 - MOVQ DX, c+48(FP) + MOVQ $0, CX // c = 0 + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V2 // if n < 0 goto V2 + +U2: // n >= 0 + // regular loop body unrolled 4x + RCRQ $1, CX // CF = c + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + SBBQ 0(R9)(SI*8), R11 + SBBQ 8(R9)(SI*8), R12 + SBBQ 16(R9)(SI*8), R13 + SBBQ 24(R9)(SI*8), R14 + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + RCLQ $1, CX // c = CF + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U2 // if n >= 0 goto U2 + +V2: ADDQ $4, DI // n += 4 + JLE E2 // if n <= 0 goto E2 + +L2: // n > 0 + RCRQ $1, CX // CF = c + MOVQ 0(R8)(SI*8), R11 + SBBQ 0(R9)(SI*8), R11 + MOVQ R11, 0(R10)(SI*8) + RCLQ $1, CX // c = CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L2 // if n > 0 goto L2 + +E2: MOVQ CX, c+72(FP) // return c RET // func addVW(z, x []Word, y Word) (c Word) TEXT ·addVW(SB),7,$0 + MOVQ z+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), CX // c = y MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVQ y+32(FP), AX // c = y - MOVL n+8(FP), R11 - MOVQ $0, BX // i = 0 - JMP E3 -L3: ADDQ (R8)(BX*8), AX - MOVQ AX, (R10)(BX*8) - RCLQ $1, AX - ANDQ $1, AX - ADDL $1, BX // i++ - -E3: CMPQ BX, R11 // i < n - JL L3 - - MOVQ AX, c+40(FP) + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V3 // if n < 4 goto V3 + +U3: // n >= 0 + // regular loop body unrolled 4x + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + ADDQ CX, R11 + ZERO_CX + ADCQ $0, R12 + ADCQ $0, R13 + ADCQ $0, R14 + SETCS CX // c = CF + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U3 // if n >= 0 goto U3 + +V3: ADDQ $4, DI // n += 4 + JLE E3 // if n <= 0 goto E3 + +L3: // n > 0 + ADDQ 0(R8)(SI*8), CX + MOVQ CX, 0(R10)(SI*8) + ZERO_CX + RCLQ $1, CX // c = CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L3 // if n > 0 goto L3 + +E3: MOVQ CX, c+56(FP) // return c RET // func subVW(z, x []Word, y Word) (c Word) +// (same as addVW except for SUBQ/SBBQ instead of ADDQ/ADCQ and label names) TEXT ·subVW(SB),7,$0 + MOVQ z+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), CX // c = y MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVQ y+32(FP), AX // c = y - MOVL n+8(FP), R11 - MOVQ $0, BX // i = 0 - JMP E4 - -L4: MOVQ (R8)(BX*8), DX // TODO(gri) is there a reverse SUBQ? - SUBQ AX, DX - MOVQ DX, (R10)(BX*8) - RCLQ $1, AX - ANDQ $1, AX - ADDL $1, BX // i++ - -E4: CMPQ BX, R11 // i < n - JL L4 - - MOVQ AX, c+40(FP) + + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V4 // if n < 4 goto V4 + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + SUBQ CX, R11 + ZERO_CX + SBBQ $0, R12 + SBBQ $0, R13 + SBBQ $0, R14 + SETCS CX // c = CF + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U4 // if n >= 0 goto U4 + +V4: ADDQ $4, DI // n += 4 + JLE E4 // if n <= 0 goto E4 + +L4: // n > 0 + MOVQ 0(R8)(SI*8), R11 + SUBQ CX, R11 + MOVQ R11, 0(R10)(SI*8) + ZERO_CX + RCLQ $1, CX // c = CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L4 // if n > 0 goto L4 + +E4: MOVQ CX, c+56(FP) // return c RET // func shlVU(z, x []Word, s uint) (c Word) TEXT ·shlVU(SB),7,$0 - MOVL n+8(FP), BX // i = n - SUBL $1, BX // i-- + MOVQ z+8(FP), BX // i = z + SUBQ $1, BX // i-- JL X8b // i < 0 (n <= 0) // n > 0 MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVL s+32(FP), CX + MOVQ x+24(FP), R8 + MOVQ s+48(FP), CX MOVQ (R8)(BX*8), AX // w1 = x[n-1] MOVQ $0, DX SHLQ CX, DX:AX // w1>>ŝ - MOVQ DX, c+40(FP) + MOVQ DX, c+56(FP) - CMPL BX, $0 + CMPQ BX, $0 JLE X8a // i <= 0 // i > 0 @@ -143,7 +268,7 @@ L8: MOVQ AX, DX // w = w1 MOVQ -8(R8)(BX*8), AX // w1 = x[i-1] SHLQ CX, DX:AX // w<<s | w1>>ŝ MOVQ DX, (R10)(BX*8) // z[i] = w<<s | w1>>ŝ - SUBL $1, BX // i-- + SUBQ $1, BX // i-- JG L8 // i > 0 // i <= 0 @@ -151,24 +276,24 @@ X8a: SHLQ CX, AX // w1<<s MOVQ AX, (R10) // z[0] = w1<<s RET -X8b: MOVQ $0, c+40(FP) +X8b: MOVQ $0, c+56(FP) RET // func shrVU(z, x []Word, s uint) (c Word) TEXT ·shrVU(SB),7,$0 - MOVL n+8(FP), R11 - SUBL $1, R11 // n-- + MOVQ z+8(FP), R11 + SUBQ $1, R11 // n-- JL X9b // n < 0 (n <= 0) // n > 0 MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVL s+32(FP), CX + MOVQ x+24(FP), R8 + MOVQ s+48(FP), CX MOVQ (R8), AX // w1 = x[0] MOVQ $0, DX SHRQ CX, DX:AX // w1<<ŝ - MOVQ DX, c+40(FP) + MOVQ DX, c+56(FP) MOVQ $0, BX // i = 0 JMP E9 @@ -178,7 +303,7 @@ L9: MOVQ AX, DX // w = w1 MOVQ 8(R8)(BX*8), AX // w1 = x[i+1] SHRQ CX, DX:AX // w>>s | w1<<ŝ MOVQ DX, (R10)(BX*8) // z[i] = w>>s | w1<<ŝ - ADDL $1, BX // i++ + ADDQ $1, BX // i++ E9: CMPQ BX, R11 JL L9 // i < n-1 @@ -188,17 +313,17 @@ X9a: SHRQ CX, AX // w1>>s MOVQ AX, (R10)(R11*8) // z[n-1] = w1>>s RET -X9b: MOVQ $0, c+40(FP) +X9b: MOVQ $0, c+56(FP) RET // func mulAddVWW(z, x []Word, y, r Word) (c Word) TEXT ·mulAddVWW(SB),7,$0 MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVQ y+32(FP), R9 - MOVQ r+40(FP), CX // c = r - MOVL n+8(FP), R11 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ r+56(FP), CX // c = r + MOVQ z+8(FP), R11 MOVQ $0, BX // i = 0 JMP E5 @@ -208,21 +333,21 @@ L5: MOVQ (R8)(BX*8), AX ADCQ $0, DX MOVQ AX, (R10)(BX*8) MOVQ DX, CX - ADDL $1, BX // i++ + ADDQ $1, BX // i++ E5: CMPQ BX, R11 // i < n JL L5 - MOVQ CX, c+48(FP) + MOVQ CX, c+64(FP) RET // func addMulVVW(z, x []Word, y Word) (c Word) TEXT ·addMulVVW(SB),7,$0 MOVQ z+0(FP), R10 - MOVQ x+16(FP), R8 - MOVQ y+32(FP), R9 - MOVL n+8(FP), R11 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z+8(FP), R11 MOVQ $0, BX // i = 0 MOVQ $0, CX // c = 0 JMP E6 @@ -234,41 +359,41 @@ L6: MOVQ (R8)(BX*8), AX ADDQ AX, (R10)(BX*8) ADCQ $0, DX MOVQ DX, CX - ADDL $1, BX // i++ + ADDQ $1, BX // i++ E6: CMPQ BX, R11 // i < n JL L6 - MOVQ CX, c+40(FP) + MOVQ CX, c+56(FP) RET // func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) TEXT ·divWVW(SB),7,$0 MOVQ z+0(FP), R10 - MOVQ xn+16(FP), DX // r = xn - MOVQ x+24(FP), R8 - MOVQ y+40(FP), R9 - MOVL n+8(FP), BX // i = n + MOVQ xn+24(FP), DX // r = xn + MOVQ x+32(FP), R8 + MOVQ y+56(FP), R9 + MOVQ z+8(FP), BX // i = z JMP E7 L7: MOVQ (R8)(BX*8), AX DIVQ R9 MOVQ AX, (R10)(BX*8) -E7: SUBL $1, BX // i-- +E7: SUBQ $1, BX // i-- JGE L7 // i >= 0 - MOVQ DX, r+48(FP) + MOVQ DX, r+64(FP) RET // func bitLen(x Word) (n int) TEXT ·bitLen(SB),7,$0 BSRQ x+0(FP), AX JZ Z1 - INCL AX - MOVL AX, n+8(FP) + ADDQ $1, AX + MOVQ AX, n+8(FP) RET -Z1: MOVL $0, n+8(FP) +Z1: MOVQ $0, n+8(FP) RET diff --git a/src/pkg/math/big/arith_arm.s b/src/pkg/math/big/arith_arm.s index dbf3360b5..64610f915 100644 --- a/src/pkg/math/big/arith_arm.s +++ b/src/pkg/math/big/arith_arm.s @@ -13,7 +13,7 @@ TEXT ·addVV(SB),7,$0 MOVW z+0(FP), R1 MOVW x+12(FP), R2 MOVW y+24(FP), R3 - MOVW n+4(FP), R4 + MOVW z+4(FP), R4 MOVW R4<<2, R4 ADD R1, R4 B E1 @@ -41,7 +41,7 @@ TEXT ·subVV(SB),7,$0 MOVW z+0(FP), R1 MOVW x+12(FP), R2 MOVW y+24(FP), R3 - MOVW n+4(FP), R4 + MOVW z+4(FP), R4 MOVW R4<<2, R4 ADD R1, R4 B E2 @@ -68,7 +68,7 @@ TEXT ·addVW(SB),7,$0 MOVW z+0(FP), R1 MOVW x+12(FP), R2 MOVW y+24(FP), R3 - MOVW n+4(FP), R4 + MOVW z+4(FP), R4 MOVW R4<<2, R4 ADD R1, R4 CMP R1, R4 @@ -102,7 +102,7 @@ TEXT ·subVW(SB),7,$0 MOVW z+0(FP), R1 MOVW x+12(FP), R2 MOVW y+24(FP), R3 - MOVW n+4(FP), R4 + MOVW z+4(FP), R4 MOVW R4<<2, R4 ADD R1, R4 CMP R1, R4 @@ -134,7 +134,7 @@ E4: // func shlVU(z, x []Word, s uint) (c Word) TEXT ·shlVU(SB),7,$0 - MOVW n+4(FP), R5 + MOVW z+4(FP), R5 CMP $0, R5 BEQ X7 @@ -183,7 +183,7 @@ X7: // func shrVU(z, x []Word, s uint) (c Word) TEXT ·shrVU(SB),7,$0 - MOVW n+4(FP), R5 + MOVW z+4(FP), R5 CMP $0, R5 BEQ X6 @@ -238,7 +238,7 @@ TEXT ·mulAddVWW(SB),7,$0 MOVW x+12(FP), R2 MOVW y+24(FP), R3 MOVW r+28(FP), R4 - MOVW n+4(FP), R5 + MOVW z+4(FP), R5 MOVW R5<<2, R5 ADD R1, R5 B E8 @@ -265,7 +265,7 @@ TEXT ·addMulVVW(SB),7,$0 MOVW z+0(FP), R1 MOVW x+12(FP), R2 MOVW y+24(FP), R3 - MOVW n+4(FP), R5 + MOVW z+4(FP), R5 MOVW R5<<2, R5 ADD R1, R5 MOVW $0, R4 @@ -314,7 +314,7 @@ TEXT ·mulWW(SB),7,$0 // func bitLen(x Word) (n int) TEXT ·bitLen(SB),7,$0 MOVW x+0(FP), R0 - WORD $0xe16f0f10 // CLZ R0, R0 (count leading zeros) + CLZ R0, R0 MOVW $32, R1 SUB.S R0, R1 MOVW R1, n+4(FP) diff --git a/src/pkg/math/big/arith_test.go b/src/pkg/math/big/arith_test.go index c7e3d284c..3615a659c 100644 --- a/src/pkg/math/big/arith_test.go +++ b/src/pkg/math/big/arith_test.go @@ -4,7 +4,10 @@ package big -import "testing" +import ( + "math/rand" + "testing" +) type funWW func(x, y, c Word) (z1, z0 Word) type argWW struct { @@ -100,6 +103,43 @@ func TestFunVV(t *testing.T) { } } +// Always the same seed for reproducible results. +var rnd = rand.New(rand.NewSource(0)) + +func rndW() Word { + return Word(rnd.Int63()<<1 | rnd.Int63n(2)) +} + +func rndV(n int) []Word { + v := make([]Word, n) + for i := range v { + v[i] = rndW() + } + return v +} + +func benchmarkFunVV(b *testing.B, f funVV, n int) { + x := rndV(n) + y := rndV(n) + z := make([]Word, n) + b.SetBytes(int64(n * _W)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + f(z, x, y) + } +} + +func BenchmarkAddVV_1(b *testing.B) { benchmarkFunVV(b, addVV, 1) } +func BenchmarkAddVV_2(b *testing.B) { benchmarkFunVV(b, addVV, 2) } +func BenchmarkAddVV_3(b *testing.B) { benchmarkFunVV(b, addVV, 3) } +func BenchmarkAddVV_4(b *testing.B) { benchmarkFunVV(b, addVV, 4) } +func BenchmarkAddVV_5(b *testing.B) { benchmarkFunVV(b, addVV, 5) } +func BenchmarkAddVV_1e1(b *testing.B) { benchmarkFunVV(b, addVV, 1e1) } +func BenchmarkAddVV_1e2(b *testing.B) { benchmarkFunVV(b, addVV, 1e2) } +func BenchmarkAddVV_1e3(b *testing.B) { benchmarkFunVV(b, addVV, 1e3) } +func BenchmarkAddVV_1e4(b *testing.B) { benchmarkFunVV(b, addVV, 1e4) } +func BenchmarkAddVV_1e5(b *testing.B) { benchmarkFunVV(b, addVV, 1e5) } + type funVW func(z, x []Word, y Word) (c Word) type argVW struct { z, x nat @@ -210,6 +250,28 @@ func TestFunVW(t *testing.T) { } } +func benchmarkFunVW(b *testing.B, f funVW, n int) { + x := rndV(n) + y := rndW() + z := make([]Word, n) + b.SetBytes(int64(n * _W)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + f(z, x, y) + } +} + +func BenchmarkAddVW_1(b *testing.B) { benchmarkFunVW(b, addVW, 1) } +func BenchmarkAddVW_2(b *testing.B) { benchmarkFunVW(b, addVW, 2) } +func BenchmarkAddVW_3(b *testing.B) { benchmarkFunVW(b, addVW, 3) } +func BenchmarkAddVW_4(b *testing.B) { benchmarkFunVW(b, addVW, 4) } +func BenchmarkAddVW_5(b *testing.B) { benchmarkFunVW(b, addVW, 5) } +func BenchmarkAddVW_1e1(b *testing.B) { benchmarkFunVW(b, addVW, 1e1) } +func BenchmarkAddVW_1e2(b *testing.B) { benchmarkFunVW(b, addVW, 1e2) } +func BenchmarkAddVW_1e3(b *testing.B) { benchmarkFunVW(b, addVW, 1e3) } +func BenchmarkAddVW_1e4(b *testing.B) { benchmarkFunVW(b, addVW, 1e4) } +func BenchmarkAddVW_1e5(b *testing.B) { benchmarkFunVW(b, addVW, 1e5) } + type funVWW func(z, x []Word, y, r Word) (c Word) type argVWW struct { z, x nat @@ -334,6 +396,28 @@ func TestMulAddWWW(t *testing.T) { } } +func benchmarkAddMulVVW(b *testing.B, n int) { + x := rndV(n) + y := rndW() + z := make([]Word, n) + b.SetBytes(int64(n * _W)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + addMulVVW(z, x, y) + } +} + +func BenchmarkAddMulVVW_1(b *testing.B) { benchmarkAddMulVVW(b, 1) } +func BenchmarkAddMulVVW_2(b *testing.B) { benchmarkAddMulVVW(b, 2) } +func BenchmarkAddMulVVW_3(b *testing.B) { benchmarkAddMulVVW(b, 3) } +func BenchmarkAddMulVVW_4(b *testing.B) { benchmarkAddMulVVW(b, 4) } +func BenchmarkAddMulVVW_5(b *testing.B) { benchmarkAddMulVVW(b, 5) } +func BenchmarkAddMulVVW_1e1(b *testing.B) { benchmarkAddMulVVW(b, 1e1) } +func BenchmarkAddMulVVW_1e2(b *testing.B) { benchmarkAddMulVVW(b, 1e2) } +func BenchmarkAddMulVVW_1e3(b *testing.B) { benchmarkAddMulVVW(b, 1e3) } +func BenchmarkAddMulVVW_1e4(b *testing.B) { benchmarkAddMulVVW(b, 1e4) } +func BenchmarkAddMulVVW_1e5(b *testing.B) { benchmarkAddMulVVW(b, 1e5) } + func testWordBitLen(t *testing.T, fname string, f func(Word) int) { for i := 0; i <= _W; i++ { x := Word(1) << uint(i-1) // i == 0 => x == 0 diff --git a/src/pkg/math/big/calibrate_test.go b/src/pkg/math/big/calibrate_test.go index efe1837bb..f69ffbf5c 100644 --- a/src/pkg/math/big/calibrate_test.go +++ b/src/pkg/math/big/calibrate_test.go @@ -21,15 +21,17 @@ import ( var calibrate = flag.Bool("calibrate", false, "run calibration test") -// measure returns the time to run f -func measure(f func()) time.Duration { - const N = 100 - start := time.Now() - for i := N; i > 0; i-- { - f() - } - stop := time.Now() - return stop.Sub(start) / N +func karatsubaLoad(b *testing.B) { + BenchmarkMul(b) +} + +// measureKaratsuba returns the time to run a Karatsuba-relevant benchmark +// given Karatsuba threshold th. +func measureKaratsuba(th int) time.Duration { + th, karatsubaThreshold = karatsubaThreshold, th + res := testing.Benchmark(karatsubaLoad) + karatsubaThreshold = th + return time.Duration(res.NsPerOp()) } func computeThresholds() { @@ -37,35 +39,33 @@ func computeThresholds() { fmt.Printf("(run repeatedly for good results)\n") // determine Tk, the work load execution time using basic multiplication - karatsubaThreshold = 1e9 // disable karatsuba - Tb := measure(benchmarkMulLoad) - fmt.Printf("Tb = %dns\n", Tb) + Tb := measureKaratsuba(1e9) // th == 1e9 => Karatsuba multiplication disabled + fmt.Printf("Tb = %10s\n", Tb) // thresholds - n := 8 // any lower values for the threshold lead to very slow multiplies + th := 4 th1 := -1 th2 := -1 var deltaOld time.Duration - for count := -1; count != 0; count-- { + for count := -1; count != 0 && th < 128; count-- { // determine Tk, the work load execution time using Karatsuba multiplication - karatsubaThreshold = n // enable karatsuba - Tk := measure(benchmarkMulLoad) + Tk := measureKaratsuba(th) // improvement over Tb delta := (Tb - Tk) * 100 / Tb - fmt.Printf("n = %3d Tk = %8dns %4d%%", n, Tk, delta) + fmt.Printf("th = %3d Tk = %10s %4d%%", th, Tk, delta) // determine break-even point if Tk < Tb && th1 < 0 { - th1 = n + th1 = th fmt.Print(" break-even point") } // determine diminishing return if 0 < delta && delta < deltaOld && th2 < 0 { - th2 = n + th2 = th fmt.Print(" diminishing return") } deltaOld = delta @@ -74,10 +74,10 @@ func computeThresholds() { // trigger counter if th1 >= 0 && th2 >= 0 && count < 0 { - count = 20 // this many extra measurements after we got both thresholds + count = 10 // this many extra measurements after we got both thresholds } - n++ + th++ } } diff --git a/src/pkg/math/big/gcd_test.go b/src/pkg/math/big/gcd_test.go new file mode 100644 index 000000000..c0b9f5830 --- /dev/null +++ b/src/pkg/math/big/gcd_test.go @@ -0,0 +1,47 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements a GCD benchmark. +// Usage: go test math/big -test.bench GCD + +package big + +import ( + "math/rand" + "testing" +) + +// randInt returns a pseudo-random Int in the range [1<<(size-1), (1<<size) - 1] +func randInt(r *rand.Rand, size uint) *Int { + n := new(Int).Lsh(intOne, size-1) + x := new(Int).Rand(r, n) + return x.Add(x, n) // make sure result > 1<<(size-1) +} + +func runGCD(b *testing.B, aSize, bSize uint) { + b.StopTimer() + var r = rand.New(rand.NewSource(1234)) + aa := randInt(r, aSize) + bb := randInt(r, bSize) + b.StartTimer() + for i := 0; i < b.N; i++ { + new(Int).GCD(nil, nil, aa, bb) + } +} + +func BenchmarkGCD10x10(b *testing.B) { runGCD(b, 10, 10) } +func BenchmarkGCD10x100(b *testing.B) { runGCD(b, 10, 100) } +func BenchmarkGCD10x1000(b *testing.B) { runGCD(b, 10, 1000) } +func BenchmarkGCD10x10000(b *testing.B) { runGCD(b, 10, 10000) } +func BenchmarkGCD10x100000(b *testing.B) { runGCD(b, 10, 100000) } +func BenchmarkGCD100x100(b *testing.B) { runGCD(b, 100, 100) } +func BenchmarkGCD100x1000(b *testing.B) { runGCD(b, 100, 1000) } +func BenchmarkGCD100x10000(b *testing.B) { runGCD(b, 100, 10000) } +func BenchmarkGCD100x100000(b *testing.B) { runGCD(b, 100, 100000) } +func BenchmarkGCD1000x1000(b *testing.B) { runGCD(b, 1000, 1000) } +func BenchmarkGCD1000x10000(b *testing.B) { runGCD(b, 1000, 10000) } +func BenchmarkGCD1000x100000(b *testing.B) { runGCD(b, 1000, 100000) } +func BenchmarkGCD10000x10000(b *testing.B) { runGCD(b, 10000, 10000) } +func BenchmarkGCD10000x100000(b *testing.B) { runGCD(b, 10000, 100000) } +func BenchmarkGCD100000x100000(b *testing.B) { runGCD(b, 100000, 100000) } diff --git a/src/pkg/math/big/int.go b/src/pkg/math/big/int.go index cd2cd0e2d..bf2fd2009 100644 --- a/src/pkg/math/big/int.go +++ b/src/pkg/math/big/int.go @@ -51,6 +51,13 @@ func (z *Int) SetInt64(x int64) *Int { return z } +// SetUint64 sets z to x and returns z. +func (z *Int) SetUint64(x uint64) *Int { + z.abs = z.abs.setUint64(uint64(x)) + z.neg = false + return z +} + // NewInt allocates and returns a new Int set to x. func NewInt(x int64) *Int { return new(Int).SetInt64(x) @@ -412,7 +419,7 @@ func (x *Int) Format(s fmt.State, ch rune) { if precisionSet { switch { case len(digits) < precision: - zeroes = precision - len(digits) // count of zero padding + zeroes = precision - len(digits) // count of zero padding case digits == "0" && precision == 0: return // print nothing if zero value (x == 0) and zero precision ("." or ".0") } @@ -519,6 +526,19 @@ func (x *Int) Int64() int64 { return v } +// Uint64 returns the uint64 representation of x. +// If x cannot be represented in an uint64, the result is undefined. +func (x *Int) Uint64() uint64 { + if len(x.abs) == 0 { + return 0 + } + v := uint64(x.abs[0]) + if _W == 32 && len(x.abs) > 1 { + v |= uint64(x.abs[1]) << 32 + } + return v +} + // SetString sets z to the value of s, interpreted in the given base, // and returns z and a boolean indicating success. If SetString fails, // the value of z is undefined but the returned value is nil. @@ -561,19 +581,18 @@ func (x *Int) BitLen() int { return x.abs.bitLen() } -// Exp sets z = x**y mod m and returns z. If m is nil, z = x**y. +// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z. +// If y <= 0, the result is 1; if m == nil or m == 0, z = x**y. // See Knuth, volume 2, section 4.6.3. func (z *Int) Exp(x, y, m *Int) *Int { if y.neg || len(y.abs) == 0 { - neg := x.neg - z.SetInt64(1) - z.neg = neg - return z + return z.SetInt64(1) } + // y > 0 var mWords nat if m != nil { - mWords = m.abs + mWords = m.abs // m.abs may be nil for m == 0 } z.abs = z.abs.expNN(x.abs, y.abs, mWords) @@ -581,12 +600,12 @@ func (z *Int) Exp(x, y, m *Int) *Int { return z } -// GCD sets z to the greatest common divisor of a and b, which must be -// positive numbers, and returns z. +// GCD sets z to the greatest common divisor of a and b, which both must +// be > 0, and returns z. // If x and y are not nil, GCD sets x and y such that z = a*x + b*y. -// If either a or b is not positive, GCD sets z = x = y = 0. +// If either a or b is <= 0, GCD sets z = x = y = 0. func (z *Int) GCD(x, y, a, b *Int) *Int { - if a.neg || b.neg { + if a.Sign() <= 0 || b.Sign() <= 0 { z.SetInt64(0) if x != nil { x.SetInt64(0) @@ -596,6 +615,9 @@ func (z *Int) GCD(x, y, a, b *Int) *Int { } return z } + if x == nil && y == nil { + return z.binaryGCD(a, b) + } A := new(Int).Set(a) B := new(Int).Set(b) @@ -640,6 +662,63 @@ func (z *Int) GCD(x, y, a, b *Int) *Int { return z } +// binaryGCD sets z to the greatest common divisor of a and b, which both must +// be > 0, and returns z. +// See Knuth, The Art of Computer Programming, Vol. 2, Section 4.5.2, Algorithm B. +func (z *Int) binaryGCD(a, b *Int) *Int { + u := z + v := new(Int) + + // use one Euclidean iteration to ensure that u and v are approx. the same size + switch { + case len(a.abs) > len(b.abs): + u.Set(b) + v.Rem(a, b) + case len(a.abs) < len(b.abs): + u.Set(a) + v.Rem(b, a) + default: + u.Set(a) + v.Set(b) + } + + // v might be 0 now + if len(v.abs) == 0 { + return u + } + // u > 0 && v > 0 + + // determine largest k such that u = u' << k, v = v' << k + k := u.abs.trailingZeroBits() + if vk := v.abs.trailingZeroBits(); vk < k { + k = vk + } + u.Rsh(u, k) + v.Rsh(v, k) + + // determine t (we know that u > 0) + t := new(Int) + if u.abs[0]&1 != 0 { + // u is odd + t.Neg(v) + } else { + t.Set(u) + } + + for len(t.abs) > 0 { + // reduce t + t.Rsh(t, t.abs.trailingZeroBits()) + if t.neg { + v.Neg(t) + } else { + u.Set(t) + } + t.Sub(u, v) + } + + return u.Lsh(u, k) +} + // ProbablyPrime performs n Miller-Rabin tests to check whether x is prime. // If it returns true, x is prime with probability 1 - 1/4^n. // If it returns false, x is not prime. @@ -697,6 +776,13 @@ func (z *Int) Rsh(x *Int, n uint) *Int { // Bit returns the value of the i'th bit of x. That is, it // returns (x>>i)&1. The bit index i must be >= 0. func (x *Int) Bit(i int) uint { + if i == 0 { + // optimization for common case: odd/even test of x + if len(x.abs) > 0 { + return uint(x.abs[0] & 1) // bit 0 is same for -x + } + return 0 + } if i < 0 { panic("negative bit index") } @@ -894,3 +980,19 @@ func (z *Int) GobDecode(buf []byte) error { z.abs = z.abs.setBytes(buf[1:]) return nil } + +// MarshalJSON implements the json.Marshaler interface. +func (x *Int) MarshalJSON() ([]byte, error) { + // TODO(gri): get rid of the []byte/string conversions + return []byte(x.String()), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (z *Int) UnmarshalJSON(x []byte) error { + // TODO(gri): get rid of the []byte/string conversions + _, ok := z.SetString(string(x), 0) + if !ok { + return fmt.Errorf("math/big: cannot unmarshal %s into a *big.Int", x) + } + return nil +} diff --git a/src/pkg/math/big/int_test.go b/src/pkg/math/big/int_test.go index 9700a9b5a..6c981e775 100644 --- a/src/pkg/math/big/int_test.go +++ b/src/pkg/math/big/int_test.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/gob" "encoding/hex" + "encoding/json" "fmt" "math/rand" "testing" @@ -642,7 +643,7 @@ func TestSetBytes(t *testing.T) { func checkBytes(b []byte) bool { b2 := new(Int).SetBytes(b).Bytes() - return bytes.Compare(b, b2) == 0 + return bytes.Equal(b, b2) } func TestBytes(t *testing.T) { @@ -766,8 +767,10 @@ var expTests = []struct { x, y, m string out string }{ + {"5", "-7", "", "1"}, + {"-5", "-7", "", "1"}, {"5", "0", "", "1"}, - {"-5", "0", "", "-1"}, + {"-5", "0", "", "1"}, {"5", "1", "", "5"}, {"-5", "1", "", "-5"}, {"-2", "3", "2", "0"}, @@ -778,6 +781,7 @@ var expTests = []struct { {"0x8000000000000000", "3", "6719", "5447"}, {"0x8000000000000000", "1000", "6719", "1603"}, {"0x8000000000000000", "1000000", "6719", "3199"}, + {"0x8000000000000000", "-1000000", "6719", "1"}, { "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", "298472983472983471903246121093472394872319615612417471234712061", @@ -806,25 +810,33 @@ func TestExp(t *testing.T) { continue } - z := y.Exp(x, y, m) - if !isNormalized(z) { - t.Errorf("#%d: %v is not normalized", i, *z) + z1 := new(Int).Exp(x, y, m) + if !isNormalized(z1) { + t.Errorf("#%d: %v is not normalized", i, *z1) } - if z.Cmp(out) != 0 { - t.Errorf("#%d: got %s want %s", i, z, out) + if z1.Cmp(out) != 0 { + t.Errorf("#%d: got %s want %s", i, z1, out) + } + + if m == nil { + // the result should be the same as for m == 0; + // specifically, there should be no div-zero panic + m = &Int{abs: nat{}} // m != nil && len(m.abs) == 0 + z2 := new(Int).Exp(x, y, m) + if z2.Cmp(z1) != 0 { + t.Errorf("#%d: got %s want %s", i, z1, z2) + } } } } func checkGcd(aBytes, bBytes []byte) bool { - a := new(Int).SetBytes(aBytes) - b := new(Int).SetBytes(bBytes) - x := new(Int) y := new(Int) - d := new(Int) + a := new(Int).SetBytes(aBytes) + b := new(Int).SetBytes(bBytes) - d.GCD(x, y, a, b) + d := new(Int).GCD(x, y, a, b) x.Mul(x, a) y.Mul(y, b) x.Add(x, y) @@ -833,32 +845,70 @@ func checkGcd(aBytes, bBytes []byte) bool { } var gcdTests = []struct { - a, b int64 - d, x, y int64 + d, x, y, a, b string }{ - {120, 23, 1, -9, 47}, -} - -func TestGcd(t *testing.T) { - for i, test := range gcdTests { - a := NewInt(test.a) - b := NewInt(test.b) + // a <= 0 || b <= 0 + {"0", "0", "0", "0", "0"}, + {"0", "0", "0", "0", "7"}, + {"0", "0", "0", "11", "0"}, + {"0", "0", "0", "-77", "35"}, + {"0", "0", "0", "64515", "-24310"}, + {"0", "0", "0", "-64515", "-24310"}, + + {"1", "-9", "47", "120", "23"}, + {"7", "1", "-2", "77", "35"}, + {"935", "-3", "8", "64515", "24310"}, + {"935000000000000000", "-3", "8", "64515000000000000000", "24310000000000000000"}, + {"1", "-221", "22059940471369027483332068679400581064239780177629666810348940098015901108344", "98920366548084643601728869055592650835572950932266967461790948584315647051443", "991"}, + + // test early exit (after one Euclidean iteration) in binaryGCD + {"1", "", "", "1", "98920366548084643601728869055592650835572950932266967461790948584315647051443"}, +} + +func testGcd(t *testing.T, d, x, y, a, b *Int) { + var X *Int + if x != nil { + X = new(Int) + } + var Y *Int + if y != nil { + Y = new(Int) + } - x := new(Int) - y := new(Int) - d := new(Int) + D := new(Int).GCD(X, Y, a, b) + if D.Cmp(d) != 0 { + t.Errorf("GCD(%s, %s): got d = %s, want %s", a, b, D, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("GCD(%s, %s): got x = %s, want %s", a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("GCD(%s, %s): got y = %s, want %s", a, b, Y, y) + } - expectedX := NewInt(test.x) - expectedY := NewInt(test.y) - expectedD := NewInt(test.d) + // binaryGCD requires a > 0 && b > 0 + if a.Sign() <= 0 || b.Sign() <= 0 { + return + } - d.GCD(x, y, a, b) + D.binaryGCD(a, b) + if D.Cmp(d) != 0 { + t.Errorf("binaryGcd(%s, %s): got d = %s, want %s", a, b, D, d) + } +} - if expectedX.Cmp(x) != 0 || - expectedY.Cmp(y) != 0 || - expectedD.Cmp(d) != 0 { - t.Errorf("#%d got (%s %s %s) want (%s %s %s)", i, x, y, d, expectedX, expectedY, expectedD) - } +func TestGcd(t *testing.T) { + for _, test := range gcdTests { + d, _ := new(Int).SetString(test.d, 0) + x, _ := new(Int).SetString(test.x, 0) + y, _ := new(Int).SetString(test.y, 0) + a, _ := new(Int).SetString(test.a, 0) + b, _ := new(Int).SetString(test.b, 0) + + testGcd(t, d, nil, nil, a, b) + testGcd(t, d, x, nil, a, b) + testGcd(t, d, nil, y, a, b) + testGcd(t, d, x, y, a, b) } quick.Check(checkGcd, nil) @@ -1085,6 +1135,36 @@ func TestInt64(t *testing.T) { } } +var uint64Tests = []uint64{ + 0, + 1, + 4294967295, + 4294967296, + 8589934591, + 8589934592, + 9223372036854775807, + 9223372036854775808, + 18446744073709551615, // 1<<64 - 1 +} + +func TestUint64(t *testing.T) { + in := new(Int) + for i, testVal := range uint64Tests { + in.SetUint64(testVal) + out := in.Uint64() + + if out != testVal { + t.Errorf("#%d got %d want %d", i, out, testVal) + } + + str := fmt.Sprint(testVal) + strOut := in.String() + if strOut != str { + t.Errorf("#%d.String got %s want %s", i, strOut, str) + } + } +} + var bitwiseTests = []struct { x, y string and, or, xor, andNot string @@ -1368,8 +1448,12 @@ func TestModInverse(t *testing.T) { } } -// used by TestIntGobEncoding and TestRatGobEncoding -var gobEncodingTests = []string{ +var encodingTests = []string{ + "-539345864568634858364538753846587364875430589374589", + "-678645873", + "-100", + "-2", + "-1", "0", "1", "2", @@ -1383,26 +1467,37 @@ func TestIntGobEncoding(t *testing.T) { var medium bytes.Buffer enc := gob.NewEncoder(&medium) dec := gob.NewDecoder(&medium) - for i, test := range gobEncodingTests { - for j := 0; j < 2; j++ { - medium.Reset() // empty buffer for each test case (in case of failures) - stest := test - if j != 0 { - // negative numbers - stest = "-" + test - } - var tx Int - tx.SetString(stest, 10) - if err := enc.Encode(&tx); err != nil { - t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err) - } - var rx Int - if err := dec.Decode(&rx); err != nil { - t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err) - } - if rx.Cmp(&tx) != 0 { - t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx) - } + for _, test := range encodingTests { + medium.Reset() // empty buffer for each test case (in case of failures) + var tx Int + tx.SetString(test, 10) + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %s failed: %s", &tx, err) + } + var rx Int + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %s failed: %s", &tx, err) + } + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx) + } + } +} + +func TestIntJSONEncoding(t *testing.T) { + for _, test := range encodingTests { + var tx Int + tx.SetString(test, 10) + b, err := json.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + } + var rx Int + if err := json.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + } + if rx.Cmp(&tx) != 0 { + t.Errorf("JSON encoding of %s failed: got %s want %s", &tx, &rx, &tx) } } } diff --git a/src/pkg/math/big/nat.go b/src/pkg/math/big/nat.go index eaa6ff066..9d09f97b7 100644 --- a/src/pkg/math/big/nat.go +++ b/src/pkg/math/big/nat.go @@ -236,7 +236,7 @@ func karatsubaSub(z, x nat, n int) { // Operands that are shorter than karatsubaThreshold are multiplied using // "grade school" multiplication; for longer operands the Karatsuba algorithm // is used. -var karatsubaThreshold int = 32 // computed by calibrate.go +var karatsubaThreshold int = 40 // computed by calibrate.go // karatsuba multiplies x and y and leaves the result in z. // Both x and y must have the same length n and n must be a @@ -342,7 +342,7 @@ func alias(x, y nat) bool { return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] } -// addAt implements z += x*(1<<(_W*i)); z must be long enough. +// addAt implements z += x<<(_W*i); z must be long enough. // (we don't use nat.add because we need z to stay the same // slice, and we don't need to normalize z after each addition) func addAt(z, x nat, i int) { @@ -396,7 +396,7 @@ func (z nat) mul(x, y nat) nat { } // use basic multiplication if the numbers are small - if n < karatsubaThreshold || n < 2 { + if n < karatsubaThreshold { z = z.make(m + n) basicMul(z, x, y) return z.norm() @@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat { // determine Karatsuba length k such that // - // x = x1*b + x0 - // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) // b = 1<<(_W*k) ("base" of digits xi, yi) // k := karatsubaLen(n) @@ -417,27 +417,44 @@ func (z nat) mul(x, y nat) nat { y0 := y[0:k] // y0 is not normalized z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y karatsuba(z, x0, y0) - z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) - // If x1 and/or y1 are not 0, add missing terms to z explicitly: + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 // - // m+n 2*k 0 - // z = [ ... | x0*y0 ] - // + [ x1*y1 ] - // + [ x1*y0 ] - // + [ x0*y1 ] + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. // if k < n || m != n { - x1 := x[k:] // x1 is normalized because x is - y1 := y[k:] // y1 is normalized because y is var t nat - t = t.mul(x1, y1) - copy(z[2*k:], t) - z[2*k+len(t):].clear() // upper portion of z is garbage - t = t.mul(x1, y0.norm()) - addAt(z, t, k) - t = t.mul(x0.norm(), y1) + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + t = t.mul(x0, y1) // update t so we don't lose t's underlying array addAt(z, t, k) + + // add xi*y0<<i, xi*y1*b<<(i+k) + y0 := y0.norm() + for i := k; i < len(x); i += k { + xi := x[i:] + if len(xi) > k { + xi = xi[:k] + } + xi = xi.norm() + t = t.mul(xi, y0) + addAt(z, t, i) + t = t.mul(xi, y1) + addAt(z, t, i+k) + } } return z.norm() @@ -493,14 +510,9 @@ func (z nat) div(z2, u, v nat) (q, r nat) { } if len(v) == 1 { - var rprime Word - q, rprime = z.divW(u, v[0]) - if rprime > 0 { - r = z2.make(1) - r[0] = rprime - } else { - r = z2.make(0) - } + var r2 Word + q, r2 = z.divW(u, v[0]) + r = z2.setWord(r2) return } @@ -740,7 +752,7 @@ func (x nat) string(charset string) string { // convert power of two and non power of two bases separately if b == b&-b { // shift is base-b digit size in bits - shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2 + shift := trailingZeroBits(b) // shift > 0 because b >= 2 mask := Word(1)<<shift - 1 w := x[0] nbits := uint(_W) // number of unprocessed bits in w @@ -814,18 +826,18 @@ func (x nat) string(charset string) string { // Convert words of q to base b digits in s. If q is large, it is recursively "split in half" // by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using -// repeated nat/Word divison. +// repeated nat/Word division. // -// The iterative method processes n Words by n divW() calls, each of which visits every Word in the -// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. -// Recursive conversion divides q by its approximate square root, yielding two parts, each half +// The iterative method processes n Words by n divW() calls, each of which visits every Word in the +// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. +// Recursive conversion divides q by its approximate square root, yielding two parts, each half // the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s // plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and -// is made better by splitting the subblocks recursively. Best is to split blocks until one more -// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the -// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the -// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and -// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for +// is made better by splitting the subblocks recursively. Best is to split blocks until one more +// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the +// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the +// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and +// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for // specific hardware. // func (q nat) convertWords(s []byte, charset string, b Word, ndigits int, bb Word, table []divisor) { @@ -908,8 +920,10 @@ type divisor struct { ndigits int // digit length of divisor in terms of output base digits } -var cacheBase10 [64]divisor // cached divisors for base 10 -var cacheLock sync.Mutex // protects cacheBase10 +var cacheBase10 struct { + sync.Mutex + table [64]divisor // cached divisors for base 10 +} // expWW computes x**y func (z nat) expWW(x, y Word) nat { @@ -925,34 +939,28 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor { // determine k where (bb**leafSize)**(2**k) >= sqrt(x) k := 1 - for words := leafSize; words < m>>1 && k < len(cacheBase10); words <<= 1 { + for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 { k++ } - // create new table of divisors or extend and reuse existing table as appropriate - var table []divisor - var cached bool - switch b { - case 10: - table = cacheBase10[0:k] // reuse old table for this conversion - cached = true - default: - table = make([]divisor, k) // new table for this conversion + // reuse and extend existing table of divisors or create new table as appropriate + var table []divisor // for b == 10, table overlaps with cacheBase10.table + if b == 10 { + cacheBase10.Lock() + table = cacheBase10.table[0:k] // reuse old table for this conversion + } else { + table = make([]divisor, k) // create new table for this conversion } // extend table if table[k-1].ndigits == 0 { - if cached { - cacheLock.Lock() // begin critical section - } - // add new entries as needed var larger nat for i := 0; i < k; i++ { if table[i].ndigits == 0 { if i == 0 { - table[i].bbb = nat(nil).expWW(bb, Word(leafSize)) - table[i].ndigits = ndigits * leafSize + table[0].bbb = nat(nil).expWW(bb, Word(leafSize)) + table[0].ndigits = ndigits * leafSize } else { table[i].bbb = nat(nil).mul(table[i-1].bbb, table[i-1].bbb) table[i].ndigits = 2 * table[i-1].ndigits @@ -968,10 +976,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor { table[i].nbits = table[i].bbb.bitLen() } } + } - if cached { - cacheLock.Unlock() // end critical section - } + if b == 10 { + cacheBase10.Unlock() } return table @@ -993,10 +1001,9 @@ var deBruijn64Lookup = []byte{ 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, } -// trailingZeroBits returns the number of consecutive zero bits on the right -// side of the given Word. -// See Knuth, volume 4, section 7.3.1 -func trailingZeroBits(x Word) int { +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func trailingZeroBits(x Word) uint { // x & -x leaves only the right-most bit set in the word. Let k be the // index of that bit. Since only a single bit is set, the value is two // to the power of k. Multiplying by a power of two is equivalent to @@ -1005,18 +1012,33 @@ func trailingZeroBits(x Word) int { // Therefore, if we have a left shifted version of this constant we can // find by how many bits it was shifted by looking at which six bit // substring ended up at the top of the word. + // (Knuth, volume 4, section 7.3.1) switch _W { case 32: - return int(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) + return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) case 64: - return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) + return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) default: - panic("Unknown word size") + panic("unknown word size") } return 0 } +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func (x nat) trailingZeroBits() uint { + if len(x) == 0 { + return 0 + } + var i uint + for x[i] == 0 { + i++ + } + // x[i] != 0 + return i*_W + trailingZeroBits(x[i]) +} + // z = x << s func (z nat) shl(x nat, s uint) nat { m := len(x) @@ -1169,29 +1191,6 @@ func (x nat) modW(d Word) (r Word) { return divWVW(q, 0, x, d) } -// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0. -func (x nat) powersOfTwoDecompose() (q nat, k int) { - if len(x) == 0 { - return x, 0 - } - - // One of the words must be non-zero by definition, - // so this loop will terminate with i < len(x), and - // i is the number of 0 words. - i := 0 - for x[i] == 0 { - i++ - } - n := trailingZeroBits(x[i]) // x[i] != 0 - - q = make(nat, len(x)-i) - shrVU(q, x[i:], uint(n)) - - q = q.norm() - k = i*_W + n - return -} - // random creates a random integer in [0..limit), using the space in z if // possible. n is the bit length of limit. func (z nat) random(rand *rand.Rand, limit nat, n int) nat { @@ -1207,17 +1206,19 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { mask := Word((1 << bitLengthOfMSW) - 1) for { - for i := range z { - switch _W { - case 32: + switch _W { + case 32: + for i := range z { z[i] = Word(rand.Uint32()) - case 64: + } + case 64: + for i := range z { z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 } + default: + panic("unknown word size") } - z[len(limit)-1] &= mask - if z.cmp(limit) < 0 { break } @@ -1226,11 +1227,11 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { return z.norm() } -// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It -// reuses the storage of z if possible. +// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; +// otherwise it sets z to x**y. The result is the value of z. func (z nat) expNN(x, y, m nat) nat { if alias(z, x) || alias(z, y) { - // We cannot allow in place modification of x or y. + // We cannot allow in-place modification of x or y. z = nil } @@ -1239,15 +1240,24 @@ func (z nat) expNN(x, y, m nat) nat { z[0] = 1 return z } + // y > 0 - if m != nil { + if len(m) != 0 { // We likely end up being as long as the modulus. z = z.make(len(m)) } z = z.set(x) - v := y[len(y)-1] - // It's invalid for the most significant word to be zero, therefore we - // will find a one bit. + + // If the base is non-trivial and the exponent is large, we use + // 4-bit, windowed exponentiation. This involves precomputing 14 values + // (x^2...x^15) but then reduces the number of multiply-reduces by a + // third. Even for a 32-bit exponent, this reduces the number of + // operations. + if len(x) > 1 && len(y) > 1 && len(m) > 0 { + return z.expNNWindowed(x, y, m) + } + + v := y[len(y)-1] // v > 0 because y is normalized and y > 0 shift := leadingZeros(v) + 1 v <<= shift var q nat @@ -1259,15 +1269,21 @@ func (z nat) expNN(x, y, m nat) nat { // we also multiply by x, thus adding one to the power. w := _W - int(shift) + // zz and r are used to avoid allocating in mul and div as + // otherwise the arguments would alias. + var zz, r nat for j := 0; j < w; j++ { - z = z.mul(z, z) + zz = zz.mul(z, z) + zz, z = z, zz if v&mask != 0 { - z = z.mul(z, x) + zz = zz.mul(z, x) + zz, z = z, zz } - if m != nil { - q, z = q.div(z, z, m) + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r } v <<= 1 @@ -1277,14 +1293,17 @@ func (z nat) expNN(x, y, m nat) nat { v = y[i] for j := 0; j < _W; j++ { - z = z.mul(z, z) + zz = zz.mul(z, z) + zz, z = z, zz if v&mask != 0 { - z = z.mul(z, x) + zz = zz.mul(z, x) + zz, z = z, zz } - if m != nil { - q, z = q.div(z, z, m) + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r } v <<= 1 @@ -1294,6 +1313,69 @@ func (z nat) expNN(x, y, m nat) nat { return z.norm() } +// expNNWindowed calculates x**y mod m using a fixed, 4-bit window. +func (z nat) expNNWindowed(x, y, m nat) nat { + // zz and r are used to avoid allocating in mul and div as otherwise + // the arguments would alias. + var zz, r nat + + const n = 4 + // powers[i] contains x^i. + var powers [1 << n]nat + powers[0] = natOne + powers[1] = x + for i := 2; i < 1<<n; i += 2 { + p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1] + *p = p.mul(*p2, *p2) + zz, r = zz.div(r, *p, m) + *p, r = r, *p + *p1 = p1.mul(*p, x) + zz, r = zz.div(r, *p1, m) + *p1, r = r, *p1 + } + + z = z.setWord(1) + + for i := len(y) - 1; i >= 0; i-- { + yi := y[i] + for j := 0; j < _W; j += n { + if i != len(y)-1 || j != 0 { + // Unrolled loop for significant performance + // gain. Use go test -bench=".*" in crypto/rsa + // to check performance before making changes. + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.mul(z, z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + } + + zz = zz.mul(z, powers[yi>>(_W-n)]) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + yi <<= n + } + } + + return z.norm() +} + // probablyPrime performs reps Miller-Rabin tests to check whether n is prime. // If it returns true, n is prime with probability 1 - 1/4^reps. // If it returns false, n is not prime. @@ -1343,8 +1425,9 @@ func (n nat) probablyPrime(reps int) bool { } nm1 := nat(nil).sub(n, natOne) - // 1<<k * q = nm1; - q, k := nm1.powersOfTwoDecompose() + // determine q, k such that nm1 = q << k + k := nm1.trailingZeroBits() + q := nat(nil).shr(nm1, k) nm3 := nat(nil).sub(nm1, natTwo) rand := rand.New(rand.NewSource(int64(n[0]))) @@ -1360,7 +1443,7 @@ NextRandom: if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { continue } - for j := 1; j < k; j++ { + for j := uint(1); j < k; j++ { y = y.mul(y, y) quotient, y = quotient.div(y, y, n) if y.cmp(nm1) == 0 { diff --git a/src/pkg/math/big/nat_test.go b/src/pkg/math/big/nat_test.go index becde5d17..2dd7bf639 100644 --- a/src/pkg/math/big/nat_test.go +++ b/src/pkg/math/big/nat_test.go @@ -6,6 +6,7 @@ package big import ( "io" + "runtime" "strings" "testing" ) @@ -62,6 +63,36 @@ var prodNN = []argNN{ {nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}}, {nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}}, {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}}, + // 3^100 * 3^28 = 3^128 + { + natFromString("11790184577738583171520872861412518665678211592275841109096961"), + natFromString("515377520732011331036461129765621272702107522001"), + natFromString("22876792454961"), + }, + // z = 111....1 (70000 digits) + // x = 10^(99*700) + ... + 10^1400 + 10^700 + 1 + // y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit) + { + natFromString(strings.Repeat("1", 70000)), + natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)), + natFromString(strings.Repeat("1", 700)), + }, + // z = 111....1 (20000 digits) + // x = 10^10000 + 1 + // y = 111....1 (10000 digits) + { + natFromString(strings.Repeat("1", 20000)), + natFromString("1" + strings.Repeat("0", 9999) + "1"), + natFromString(strings.Repeat("1", 10000)), + }, +} + +func natFromString(s string) nat { + x, _, err := nat(nil).scan(strings.NewReader(s), 0) + if err != nil { + panic(err) + } + return x } func TestSet(t *testing.T) { @@ -135,26 +166,43 @@ func TestMulRangeN(t *testing.T) { } } -var mulArg, mulTmp nat - -func init() { - const n = 1000 - mulArg = make(nat, n) - for i := 0; i < n; i++ { - mulArg[i] = _M +// allocBytes returns the number of bytes allocated by invoking f. +func allocBytes(f func()) uint64 { + var stats runtime.MemStats + runtime.ReadMemStats(&stats) + t := stats.TotalAlloc + f() + runtime.ReadMemStats(&stats) + return stats.TotalAlloc - t +} + +// TestMulUnbalanced tests that multiplying numbers of different lengths +// does not cause deep recursion and in turn allocate too much memory. +// Test case for issue 3807. +func TestMulUnbalanced(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + x := rndNat(50000) + y := rndNat(40) + allocSize := allocBytes(func() { + nat(nil).mul(x, y) + }) + inputSize := uint64(len(x)+len(y)) * _S + if ratio := allocSize / uint64(inputSize); ratio > 10 { + t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio) } } -func benchmarkMulLoad() { - for j := 1; j <= 10; j++ { - x := mulArg[0 : j*100] - mulTmp.mul(x, x) - } +func rndNat(n int) nat { + return nat(rndV(n)).norm() } func BenchmarkMul(b *testing.B) { + mulx := rndNat(1e4) + muly := rndNat(1e4) + b.ResetTimer() for i := 0; i < b.N; i++ { - benchmarkMulLoad() + var z nat + z.mul(mulx, muly) } } @@ -362,6 +410,20 @@ func TestScanPi(t *testing.T) { } } +func TestScanPiParallel(t *testing.T) { + const n = 2 + c := make(chan int) + for i := 0; i < n; i++ { + go func() { + TestScanPi(t) + c <- 0 + }() + } + for i := 0; i < n; i++ { + <-c + } +} + func BenchmarkScanPi(b *testing.B) { for i := 0; i < b.N; i++ { var x nat @@ -369,6 +431,28 @@ func BenchmarkScanPi(b *testing.B) { } } +func BenchmarkStringPiParallel(b *testing.B) { + var x nat + x, _, _ = x.scan(strings.NewReader(pi), 0) + if x.decimalString() != pi { + panic("benchmark incorrect: conversion failed") + } + n := runtime.GOMAXPROCS(0) + m := b.N / n // n*m <= b.N due to flooring, but the error is neglibible (n is not very large) + c := make(chan int, n) + for i := 0; i < n; i++ { + go func() { + for j := 0; j < m; j++ { + x.decimalString() + } + c <- 0 + }() + } + for i := 0; i < n; i++ { + <-c + } +} + func BenchmarkScan10Base2(b *testing.B) { ScanHelper(b, 2, 10, 10) } func BenchmarkScan100Base2(b *testing.B) { ScanHelper(b, 2, 10, 100) } func BenchmarkScan1000Base2(b *testing.B) { ScanHelper(b, 2, 10, 1000) } @@ -463,13 +547,13 @@ func BenchmarkLeafSize13(b *testing.B) { LeafSizeHelper(b, 10, 13) } func BenchmarkLeafSize14(b *testing.B) { LeafSizeHelper(b, 10, 14) } func BenchmarkLeafSize15(b *testing.B) { LeafSizeHelper(b, 10, 15) } func BenchmarkLeafSize16(b *testing.B) { LeafSizeHelper(b, 10, 16) } -func BenchmarkLeafSize32(b *testing.B) { LeafSizeHelper(b, 10, 32) } // try some large lengths +func BenchmarkLeafSize32(b *testing.B) { LeafSizeHelper(b, 10, 32) } // try some large lengths func BenchmarkLeafSize64(b *testing.B) { LeafSizeHelper(b, 10, 64) } func LeafSizeHelper(b *testing.B, base Word, size int) { b.StopTimer() originalLeafSize := leafSize - resetTable(cacheBase10[:]) + resetTable(cacheBase10.table[:]) leafSize = size b.StartTimer() @@ -486,7 +570,7 @@ func LeafSizeHelper(b *testing.B, base Word, size int) { } b.StopTimer() - resetTable(cacheBase10[:]) + resetTable(cacheBase10.table[:]) leafSize = originalLeafSize b.StartTimer() } @@ -616,14 +700,23 @@ func TestModW(t *testing.T) { } func TestTrailingZeroBits(t *testing.T) { - var x Word - x-- - for i := 0; i < _W; i++ { - if trailingZeroBits(x) != i { - t.Errorf("Failed at step %d: x: %x got: %d", i, x, trailingZeroBits(x)) + x := Word(1) + for i := uint(0); i <= _W; i++ { + n := trailingZeroBits(x) + if n != i%_W { + t.Errorf("got trailingZeroBits(%#x) = %d; want %d", x, n, i%_W) } x <<= 1 } + + y := nat(nil).set(natOne) + for i := uint(0); i <= 3*_W; i++ { + n := y.trailingZeroBits() + if n != i { + t.Errorf("got 0x%s.trailingZeroBits() = %d; want %d", y.string(lowercaseDigits[0:16]), n, i) + } + y = y.shl(y, 1) + } } var expNNTests = []struct { diff --git a/src/pkg/math/big/rat.go b/src/pkg/math/big/rat.go index 7bd83fc0f..3e6473d92 100644 --- a/src/pkg/math/big/rat.go +++ b/src/pkg/math/big/rat.go @@ -10,14 +10,17 @@ import ( "encoding/binary" "errors" "fmt" + "math" "strings" ) // A Rat represents a quotient a/b of arbitrary precision. // The zero value for a Rat represents the value 0. type Rat struct { - a Int - b nat // len(b) == 0 acts like b == 1 + // To make zero values for Rat work w/o initialization, + // a zero value of b (len(b) == 0) acts like b == 1. + // a.neg determines the sign of the Rat, b.neg is ignored. + a, b Int } // NewRat creates a new Rat with numerator a and denominator b. @@ -25,6 +28,156 @@ func NewRat(a, b int64) *Rat { return new(Rat).SetFrac64(a, b) } +// SetFloat64 sets z to exactly f and returns z. +// If f is not finite, SetFloat returns nil. +func (z *Rat) SetFloat64(f float64) *Rat { + const expMask = 1<<11 - 1 + bits := math.Float64bits(f) + mantissa := bits & (1<<52 - 1) + exp := int((bits >> 52) & expMask) + switch exp { + case expMask: // non-finite + return nil + case 0: // denormal + exp -= 1022 + default: // normal + mantissa |= 1 << 52 + exp -= 1023 + } + + shift := 52 - exp + + // Optimisation (?): partially pre-normalise. + for mantissa&1 == 0 && shift > 0 { + mantissa >>= 1 + shift-- + } + + z.a.SetUint64(mantissa) + z.a.neg = f < 0 + z.b.Set(intOne) + if shift > 0 { + z.b.Lsh(&z.b, uint(shift)) + } else { + z.a.Lsh(&z.a, uint(-shift)) + } + return z.norm() +} + +// isFinite reports whether f represents a finite rational value. +// It is equivalent to !math.IsNan(f) && !math.IsInf(f, 0). +func isFinite(f float64) bool { + return math.Abs(f) <= math.MaxFloat64 +} + +// low64 returns the least significant 64 bits of natural number z. +func low64(z nat) uint64 { + if len(z) == 0 { + return 0 + } + if _W == 32 && len(z) > 1 { + return uint64(z[1])<<32 | uint64(z[0]) + } + return uint64(z[0]) +} + +// quotToFloat returns the non-negative IEEE 754 double-precision +// value nearest to the quotient a/b, using round-to-even in halfway +// cases. It does not mutate its arguments. +// Preconditions: b is non-zero; a and b have no common factors. +func quotToFloat(a, b nat) (f float64, exact bool) { + // TODO(adonovan): specialize common degenerate cases: 1.0, integers. + alen := a.bitLen() + if alen == 0 { + return 0, true + } + blen := b.bitLen() + if blen == 0 { + panic("division by zero") + } + + // 1. Left-shift A or B such that quotient A/B is in [1<<53, 1<<55). + // (54 bits if A<B when they are left-aligned, 55 bits if A>=B.) + // This is 2 or 3 more than the float64 mantissa field width of 52: + // - the optional extra bit is shifted away in step 3 below. + // - the high-order 1 is omitted in float64 "normal" representation; + // - the low-order 1 will be used during rounding then discarded. + exp := alen - blen + var a2, b2 nat + a2 = a2.set(a) + b2 = b2.set(b) + if shift := 54 - exp; shift > 0 { + a2 = a2.shl(a2, uint(shift)) + } else if shift < 0 { + b2 = b2.shl(b2, uint(-shift)) + } + + // 2. Compute quotient and remainder (q, r). NB: due to the + // extra shift, the low-order bit of q is logically the + // high-order bit of r. + var q nat + q, r := q.div(a2, a2, b2) // (recycle a2) + mantissa := low64(q) + haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half + + // 3. If quotient didn't fit in 54 bits, re-do division by b2<<1 + // (in effect---we accomplish this incrementally). + if mantissa>>54 == 1 { + if mantissa&1 == 1 { + haveRem = true + } + mantissa >>= 1 + exp++ + } + if mantissa>>53 != 1 { + panic("expected exactly 54 bits of result") + } + + // 4. Rounding. + if -1022-52 <= exp && exp <= -1022 { + // Denormal case; lose 'shift' bits of precision. + shift := uint64(-1022 - (exp - 1)) // [1..53) + lostbits := mantissa & (1<<shift - 1) + haveRem = haveRem || lostbits != 0 + mantissa >>= shift + exp = -1023 + 2 + } + // Round q using round-half-to-even. + exact = !haveRem + if mantissa&1 != 0 { + exact = false + if haveRem || mantissa&2 != 0 { + if mantissa++; mantissa >= 1<<54 { + // Complete rollover 11...1 => 100...0, so shift is safe + mantissa >>= 1 + exp++ + } + } + } + mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 2^53. + + f = math.Ldexp(float64(mantissa), exp-53) + if math.IsInf(f, 0) { + exact = false + } + return +} + +// Float64 returns the nearest float64 value to z. +// If z is exactly representable as a float64, Float64 returns exact=true. +// If z is negative, so too is f, even if f==0. +func (z *Rat) Float64() (f float64, exact bool) { + b := z.b.abs + if len(b) == 0 { + b = b.set(natOne) // materialize denominator + } + f, exact = quotToFloat(z.a.abs, b) + if z.a.neg { + f = -f + } + return +} + // SetFrac sets z to a/b and returns z. func (z *Rat) SetFrac(a, b *Int) *Rat { z.a.neg = a.neg != b.neg @@ -36,7 +189,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat { babs = nat(nil).set(babs) // make a copy } z.a.abs = z.a.abs.set(a.abs) - z.b = z.b.set(babs) + z.b.abs = z.b.abs.set(babs) return z.norm() } @@ -50,21 +203,21 @@ func (z *Rat) SetFrac64(a, b int64) *Rat { b = -b z.a.neg = !z.a.neg } - z.b = z.b.setUint64(uint64(b)) + z.b.abs = z.b.abs.setUint64(uint64(b)) return z.norm() } // SetInt sets z to x (by making a copy of x) and returns z. func (z *Rat) SetInt(x *Int) *Rat { z.a.Set(x) - z.b = z.b.make(0) + z.b.abs = z.b.abs.make(0) return z } // SetInt64 sets z to x and returns z. func (z *Rat) SetInt64(x int64) *Rat { z.a.SetInt64(x) - z.b = z.b.make(0) + z.b.abs = z.b.abs.make(0) return z } @@ -72,7 +225,7 @@ func (z *Rat) SetInt64(x int64) *Rat { func (z *Rat) Set(x *Rat) *Rat { if z != x { z.a.Set(&x.a) - z.b = z.b.set(x.b) + z.b.Set(&x.b) } return z } @@ -97,15 +250,15 @@ func (z *Rat) Inv(x *Rat) *Rat { panic("division by zero") } z.Set(x) - a := z.b + a := z.b.abs if len(a) == 0 { - a = a.setWord(1) // materialize numerator + a = a.set(natOne) // materialize numerator } b := z.a.abs if b.cmp(natOne) == 0 { b = b.make(0) // normalize denominator } - z.a.abs, z.b = a, b // sign doesn't change + z.a.abs, z.b.abs = a, b // sign doesn't change return z } @@ -121,38 +274,26 @@ func (x *Rat) Sign() int { // IsInt returns true if the denominator of x is 1. func (x *Rat) IsInt() bool { - return len(x.b) == 0 || x.b.cmp(natOne) == 0 + return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0 } // Num returns the numerator of x; it may be <= 0. // The result is a reference to x's numerator; it -// may change if a new value is assigned to x. +// may change if a new value is assigned to x, and vice versa. +// The sign of the numerator corresponds to the sign of x. func (x *Rat) Num() *Int { return &x.a } // Denom returns the denominator of x; it is always > 0. // The result is a reference to x's denominator; it -// may change if a new value is assigned to x. +// may change if a new value is assigned to x, and vice versa. func (x *Rat) Denom() *Int { - if len(x.b) == 0 { - return &Int{abs: nat{1}} + x.b.neg = false // the result is always >= 0 + if len(x.b.abs) == 0 { + x.b.abs = x.b.abs.set(natOne) // materialize denominator } - return &Int{abs: x.b} -} - -func gcd(x, y nat) nat { - // Euclidean algorithm. - var a, b nat - a = a.set(x) - b = b.set(y) - for len(b) != 0 { - var q, r nat - _, r = q.div(r, a, b) - a = b - b = r - } - return a + return &x.b } func (z *Rat) norm() *Rat { @@ -160,17 +301,25 @@ func (z *Rat) norm() *Rat { case len(z.a.abs) == 0: // z == 0 - normalize sign and denominator z.a.neg = false - z.b = z.b.make(0) - case len(z.b) == 0: + z.b.abs = z.b.abs.make(0) + case len(z.b.abs) == 0: // z is normalized int - nothing to do - case z.b.cmp(natOne) == 0: + case z.b.abs.cmp(natOne) == 0: // z is int - normalize denominator - z.b = z.b.make(0) + z.b.abs = z.b.abs.make(0) default: - if f := gcd(z.a.abs, z.b); f.cmp(natOne) != 0 { - z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f) - z.b, _ = z.b.div(nil, z.b, f) + neg := z.a.neg + z.a.neg = false + z.b.neg = false + if f := NewInt(0).binaryGCD(&z.a, &z.b); f.Cmp(intOne) != 0 { + z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs) + z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs) + if z.b.abs.cmp(natOne) == 0 { + // z is int - normalize denominator + z.b.abs = z.b.abs.make(0) + } } + z.a.neg = neg } return z } @@ -207,31 +356,31 @@ func scaleDenom(x *Int, f nat) *Int { // +1 if x > y // func (x *Rat) Cmp(y *Rat) int { - return scaleDenom(&x.a, y.b).Cmp(scaleDenom(&y.a, x.b)) + return scaleDenom(&x.a, y.b.abs).Cmp(scaleDenom(&y.a, x.b.abs)) } // Add sets z to the sum x+y and returns z. func (z *Rat) Add(x, y *Rat) *Rat { - a1 := scaleDenom(&x.a, y.b) - a2 := scaleDenom(&y.a, x.b) + a1 := scaleDenom(&x.a, y.b.abs) + a2 := scaleDenom(&y.a, x.b.abs) z.a.Add(a1, a2) - z.b = mulDenom(z.b, x.b, y.b) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Sub sets z to the difference x-y and returns z. func (z *Rat) Sub(x, y *Rat) *Rat { - a1 := scaleDenom(&x.a, y.b) - a2 := scaleDenom(&y.a, x.b) + a1 := scaleDenom(&x.a, y.b.abs) + a2 := scaleDenom(&y.a, x.b.abs) z.a.Sub(a1, a2) - z.b = mulDenom(z.b, x.b, y.b) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Mul sets z to the product x*y and returns z. func (z *Rat) Mul(x, y *Rat) *Rat { z.a.Mul(&x.a, &y.a) - z.b = mulDenom(z.b, x.b, y.b) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) return z.norm() } @@ -241,10 +390,10 @@ func (z *Rat) Quo(x, y *Rat) *Rat { if len(y.a.abs) == 0 { panic("division by zero") } - a := scaleDenom(&x.a, y.b) - b := scaleDenom(&y.a, x.b) + a := scaleDenom(&x.a, y.b.abs) + b := scaleDenom(&y.a, x.b.abs) z.a.abs = a.abs - z.b = b.abs + z.b.abs = b.abs z.a.neg = a.neg != b.neg return z.norm() } @@ -286,7 +435,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { } s = s[sep+1:] var err error - if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil { + if z.b.abs, _, err = z.b.abs.scan(strings.NewReader(s), 10); err != nil { return nil, false } return z.norm(), true @@ -317,11 +466,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) { } powTen := nat(nil).expNN(natTen, exp.abs, nil) if exp.neg { - z.b = powTen + z.b.abs = powTen z.norm() } else { z.a.abs = z.a.abs.mul(z.a.abs, powTen) - z.b = z.b.make(0) + z.b.abs = z.b.abs.make(0) } return z, true @@ -330,8 +479,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) { // String returns a string representation of z in the form "a/b" (even if b == 1). func (x *Rat) String() string { s := "/1" - if len(x.b) != 0 { - s = "/" + x.b.decimalString() + if len(x.b.abs) != 0 { + s = "/" + x.b.abs.decimalString() } return x.a.String() + s } @@ -355,9 +504,9 @@ func (x *Rat) FloatString(prec int) string { } return s } - // x.b != 0 + // x.b.abs != 0 - q, r := nat(nil).div(nat(nil), x.a.abs, x.b) + q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs) p := natOne if prec > 0 { @@ -365,11 +514,11 @@ func (x *Rat) FloatString(prec int) string { } r = r.mul(r, p) - r, r2 := r.div(nat(nil), r, x.b) + r, r2 := r.div(nat(nil), r, x.b.abs) // see if we need to round up r2 = r2.add(r2, r2) - if x.b.cmp(r2) <= 0 { + if x.b.abs.cmp(r2) <= 0 { r = r.add(r, natOne) if r.cmp(p) >= 0 { q = nat(nil).add(q, natOne) @@ -396,8 +545,8 @@ const ratGobVersion byte = 1 // GobEncode implements the gob.GobEncoder interface. func (x *Rat) GobEncode() ([]byte, error) { - buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b))*_S) // extra bytes for version and sign bit (1), and numerator length (4) - i := x.b.bytes(buf) + buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4) + i := x.b.abs.bytes(buf) j := x.a.abs.bytes(buf[0:i]) n := i - j if int(uint32(n)) != n { @@ -427,6 +576,6 @@ func (z *Rat) GobDecode(buf []byte) error { i := j + binary.BigEndian.Uint32(buf[j-4:j]) z.a.neg = b&1 != 0 z.a.abs = z.a.abs.setBytes(buf[j:i]) - z.b = z.b.setBytes(buf[i:]) + z.b.abs = z.b.abs.setBytes(buf[i:]) return nil } diff --git a/src/pkg/math/big/rat_test.go b/src/pkg/math/big/rat_test.go index f7f31ae1a..462dfb723 100644 --- a/src/pkg/math/big/rat_test.go +++ b/src/pkg/math/big/rat_test.go @@ -8,6 +8,9 @@ import ( "bytes" "encoding/gob" "fmt" + "math" + "strconv" + "strings" "testing" ) @@ -387,30 +390,19 @@ func TestRatGobEncoding(t *testing.T) { var medium bytes.Buffer enc := gob.NewEncoder(&medium) dec := gob.NewDecoder(&medium) - for i, test := range gobEncodingTests { - for j := 0; j < 4; j++ { - medium.Reset() // empty buffer for each test case (in case of failures) - stest := test - if j&1 != 0 { - // negative numbers - stest = "-" + test - } - if j%2 != 0 { - // fractions - stest = stest + "." + test - } - var tx Rat - tx.SetString(stest) - if err := enc.Encode(&tx); err != nil { - t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err) - } - var rx Rat - if err := dec.Decode(&rx); err != nil { - t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err) - } - if rx.Cmp(&tx) != 0 { - t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx) - } + for _, test := range encodingTests { + medium.Reset() // empty buffer for each test case (in case of failures) + var tx Rat + tx.SetString(test + ".14159265") + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %s failed: %s", &tx, err) + } + var rx Rat + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %s failed: %s", &tx, err) + } + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx) } } } @@ -454,3 +446,462 @@ func TestIssue2379(t *testing.T) { t.Errorf("5) got %s want %s", x, q) } } + +func TestIssue3521(t *testing.T) { + a := new(Int) + b := new(Int) + a.SetString("64375784358435883458348587", 0) + b.SetString("4789759874531", 0) + + // 0) a raw zero value has 1 as denominator + zero := new(Rat) + one := NewInt(1) + if zero.Denom().Cmp(one) != 0 { + t.Errorf("0) got %s want %s", zero.Denom(), one) + } + + // 1a) a zero value remains zero independent of denominator + x := new(Rat) + x.Denom().Set(new(Int).Neg(b)) + if x.Cmp(zero) != 0 { + t.Errorf("1a) got %s want %s", x, zero) + } + + // 1b) a zero value may have a denominator != 0 and != 1 + x.Num().Set(a) + qab := new(Rat).SetFrac(a, b) + if x.Cmp(qab) != 0 { + t.Errorf("1b) got %s want %s", x, qab) + } + + // 2a) an integral value becomes a fraction depending on denominator + x.SetFrac64(10, 2) + x.Denom().SetInt64(3) + q53 := NewRat(5, 3) + if x.Cmp(q53) != 0 { + t.Errorf("2a) got %s want %s", x, q53) + } + + // 2b) an integral value becomes a fraction depending on denominator + x = NewRat(10, 2) + x.Denom().SetInt64(3) + if x.Cmp(q53) != 0 { + t.Errorf("2b) got %s want %s", x, q53) + } + + // 3) changing the numerator/denominator of a Rat changes the Rat + x.SetFrac(a, b) + a = x.Num() + b = x.Denom() + a.SetInt64(5) + b.SetInt64(3) + if x.Cmp(q53) != 0 { + t.Errorf("3) got %s want %s", x, q53) + } +} + +// Test inputs to Rat.SetString. The prefix "long:" causes the test +// to be skipped in --test.short mode. (The threshold is about 500us.) +var float64inputs = []string{ + // + // Constants plundered from strconv/testfp.txt. + // + + // Table 1: Stress Inputs for Conversion to 53-bit Binary, < 1/2 ULP + "5e+125", + "69e+267", + "999e-026", + "7861e-034", + "75569e-254", + "928609e-261", + "9210917e+080", + "84863171e+114", + "653777767e+273", + "5232604057e-298", + "27235667517e-109", + "653532977297e-123", + "3142213164987e-294", + "46202199371337e-072", + "231010996856685e-073", + "9324754620109615e+212", + "78459735791271921e+049", + "272104041512242479e+200", + "6802601037806061975e+198", + "20505426358836677347e-221", + "836168422905420598437e-234", + "4891559871276714924261e+222", + + // Table 2: Stress Inputs for Conversion to 53-bit Binary, > 1/2 ULP + "9e-265", + "85e-037", + "623e+100", + "3571e+263", + "81661e+153", + "920657e-023", + "4603285e-024", + "87575437e-309", + "245540327e+122", + "6138508175e+120", + "83356057653e+193", + "619534293513e+124", + "2335141086879e+218", + "36167929443327e-159", + "609610927149051e-255", + "3743626360493413e-165", + "94080055902682397e-242", + "899810892172646163e+283", + "7120190517612959703e+120", + "25188282901709339043e-252", + "308984926168550152811e-052", + "6372891218502368041059e+064", + + // Table 14: Stress Inputs for Conversion to 24-bit Binary, <1/2 ULP + "5e-20", + "67e+14", + "985e+15", + "7693e-42", + "55895e-16", + "996622e-44", + "7038531e-32", + "60419369e-46", + "702990899e-20", + "6930161142e-48", + "25933168707e+13", + "596428896559e+20", + + // Table 15: Stress Inputs for Conversion to 24-bit Binary, >1/2 ULP + "3e-23", + "57e+18", + "789e-35", + "2539e-18", + "76173e+28", + "887745e-11", + "5382571e-37", + "82381273e-35", + "750486563e-38", + "3752432815e-39", + "75224575729e-45", + "459926601011e+15", + + // + // Constants plundered from strconv/atof_test.go. + // + + "0", + "1", + "+1", + "1e23", + "1E23", + "100000000000000000000000", + "1e-100", + "123456700", + "99999999999999974834176", + "100000000000000000000001", + "100000000000000008388608", + "100000000000000016777215", + "100000000000000016777216", + "-1", + "-0.1", + "-0", // NB: exception made for this input + "1e-20", + "625e-3", + + // largest float64 + "1.7976931348623157e308", + "-1.7976931348623157e308", + // next float64 - too large + "1.7976931348623159e308", + "-1.7976931348623159e308", + // the border is ...158079 + // borderline - okay + "1.7976931348623158e308", + "-1.7976931348623158e308", + // borderline - too large + "1.797693134862315808e308", + "-1.797693134862315808e308", + + // a little too large + "1e308", + "2e308", + "1e309", + + // way too large + "1e310", + "-1e310", + "1e400", + "-1e400", + "long:1e400000", + "long:-1e400000", + + // denormalized + "1e-305", + "1e-306", + "1e-307", + "1e-308", + "1e-309", + "1e-310", + "1e-322", + // smallest denormal + "5e-324", + "4e-324", + "3e-324", + // too small + "2e-324", + // way too small + "1e-350", + "long:1e-400000", + // way too small, negative + "-1e-350", + "long:-1e-400000", + + // try to overflow exponent + // [Disabled: too slow and memory-hungry with rationals.] + // "1e-4294967296", + // "1e+4294967296", + // "1e-18446744073709551616", + // "1e+18446744073709551616", + + // http://www.exploringbinary.com/java-hangs-when-converting-2-2250738585072012e-308/ + "2.2250738585072012e-308", + // http://www.exploringbinary.com/php-hangs-on-numeric-value-2-2250738585072011e-308/ + + "2.2250738585072011e-308", + + // A very large number (initially wrongly parsed by the fast algorithm). + "4.630813248087435e+307", + + // A different kind of very large number. + "22.222222222222222", + "long:2." + strings.Repeat("2", 4000) + "e+1", + + // Exactly halfway between 1 and math.Nextafter(1, 2). + // Round to even (down). + "1.00000000000000011102230246251565404236316680908203125", + // Slightly lower; still round down. + "1.00000000000000011102230246251565404236316680908203124", + // Slightly higher; round up. + "1.00000000000000011102230246251565404236316680908203126", + // Slightly higher, but you have to read all the way to the end. + "long:1.00000000000000011102230246251565404236316680908203125" + strings.Repeat("0", 10000) + "1", + + // Smallest denormal, 2^(-1022-52) + "4.940656458412465441765687928682213723651e-324", + // Half of smallest denormal, 2^(-1022-53) + "2.470328229206232720882843964341106861825e-324", + // A little more than the exact half of smallest denormal + // 2^-1075 + 2^-1100. (Rounds to 1p-1074.) + "2.470328302827751011111470718709768633275e-324", + // The exact halfway between smallest normal and largest denormal: + // 2^-1022 - 2^-1075. (Rounds to 2^-1022.) + "2.225073858507201136057409796709131975935e-308", + + "1152921504606846975", // 1<<60 - 1 + "-1152921504606846975", // -(1<<60 - 1) + "1152921504606846977", // 1<<60 + 1 + "-1152921504606846977", // -(1<<60 + 1) + + "1/3", +} + +func TestFloat64SpecialCases(t *testing.T) { + for _, input := range float64inputs { + if strings.HasPrefix(input, "long:") { + if testing.Short() { + continue + } + input = input[len("long:"):] + } + + r, ok := new(Rat).SetString(input) + if !ok { + t.Errorf("Rat.SetString(%q) failed", input) + continue + } + f, exact := r.Float64() + + // 1. Check string -> Rat -> float64 conversions are + // consistent with strconv.ParseFloat. + // Skip this check if the input uses "a/b" rational syntax. + if !strings.Contains(input, "/") { + e, _ := strconv.ParseFloat(input, 64) + + // Careful: negative Rats too small for + // float64 become -0, but Rat obviously cannot + // preserve the sign from SetString("-0"). + switch { + case math.Float64bits(e) == math.Float64bits(f): + // Ok: bitwise equal. + case f == 0 && r.Num().BitLen() == 0: + // Ok: Rat(0) is equivalent to both +/- float64(0). + default: + t.Errorf("strconv.ParseFloat(%q) = %g (%b), want %g (%b); delta=%g", input, e, e, f, f, f-e) + } + } + + if !isFinite(f) { + continue + } + + // 2. Check f is best approximation to r. + if !checkIsBestApprox(t, f, r) { + // Append context information. + t.Errorf("(input was %q)", input) + } + + // 3. Check f->R->f roundtrip is non-lossy. + checkNonLossyRoundtrip(t, f) + + // 4. Check exactness using slow algorithm. + if wasExact := new(Rat).SetFloat64(f).Cmp(r) == 0; wasExact != exact { + t.Errorf("Rat.SetString(%q).Float64().exact = %t, want %t", input, exact, wasExact) + } + } +} + +func TestFloat64Distribution(t *testing.T) { + // Generate a distribution of (sign, mantissa, exp) values + // broader than the float64 range, and check Rat.Float64() + // always picks the closest float64 approximation. + var add = []int64{ + 0, + 1, + 3, + 5, + 7, + 9, + 11, + } + var winc, einc = uint64(1), int(1) // soak test (~75s on x86-64) + if testing.Short() { + winc, einc = 10, 500 // quick test (~12ms on x86-64) + } + + for _, sign := range "+-" { + for _, a := range add { + for wid := uint64(0); wid < 60; wid += winc { + b := int64(1<<wid + a) + if sign == '-' { + b = -b + } + for exp := -1100; exp < 1100; exp += einc { + num, den := NewInt(b), NewInt(1) + if exp > 0 { + num.Lsh(num, uint(exp)) + } else { + den.Lsh(den, uint(-exp)) + } + r := new(Rat).SetFrac(num, den) + f, _ := r.Float64() + + if !checkIsBestApprox(t, f, r) { + // Append context information. + t.Errorf("(input was mantissa %#x, exp %d; f=%g (%b); f~%g; r=%v)", + b, exp, f, f, math.Ldexp(float64(b), exp), r) + } + + checkNonLossyRoundtrip(t, f) + } + } + } + } +} + +// TestFloat64NonFinite checks that SetFloat64 of a non-finite value +// returns nil. +func TestSetFloat64NonFinite(t *testing.T) { + for _, f := range []float64{math.NaN(), math.Inf(+1), math.Inf(-1)} { + var r Rat + if r2 := r.SetFloat64(f); r2 != nil { + t.Errorf("SetFloat64(%g) was %v, want nil", f, r2) + } + } +} + +// checkNonLossyRoundtrip checks that a float->Rat->float roundtrip is +// non-lossy for finite f. +func checkNonLossyRoundtrip(t *testing.T, f float64) { + if !isFinite(f) { + return + } + r := new(Rat).SetFloat64(f) + if r == nil { + t.Errorf("Rat.SetFloat64(%g (%b)) == nil", f, f) + return + } + f2, exact := r.Float64() + if f != f2 || !exact { + t.Errorf("Rat.SetFloat64(%g).Float64() = %g (%b), %v, want %g (%b), %v; delta=%b", + f, f2, f2, exact, f, f, true, f2-f) + } +} + +// delta returns the absolute difference between r and f. +func delta(r *Rat, f float64) *Rat { + d := new(Rat).Sub(r, new(Rat).SetFloat64(f)) + return d.Abs(d) +} + +// checkIsBestApprox checks that f is the best possible float64 +// approximation of r. +// Returns true on success. +func checkIsBestApprox(t *testing.T, f float64, r *Rat) bool { + if math.Abs(f) >= math.MaxFloat64 { + // Cannot check +Inf, -Inf, nor the float next to them (MaxFloat64). + // But we have tests for these special cases. + return true + } + + // r must be strictly between f0 and f1, the floats bracketing f. + f0 := math.Nextafter(f, math.Inf(-1)) + f1 := math.Nextafter(f, math.Inf(+1)) + + // For f to be correct, r must be closer to f than to f0 or f1. + df := delta(r, f) + df0 := delta(r, f0) + df1 := delta(r, f1) + if df.Cmp(df0) > 0 { + t.Errorf("Rat(%v).Float64() = %g (%b), but previous float64 %g (%b) is closer", r, f, f, f0, f0) + return false + } + if df.Cmp(df1) > 0 { + t.Errorf("Rat(%v).Float64() = %g (%b), but next float64 %g (%b) is closer", r, f, f, f1, f1) + return false + } + if df.Cmp(df0) == 0 && !isEven(f) { + t.Errorf("Rat(%v).Float64() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f0, f0) + return false + } + if df.Cmp(df1) == 0 && !isEven(f) { + t.Errorf("Rat(%v).Float64() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f1, f1) + return false + } + return true +} + +func isEven(f float64) bool { return math.Float64bits(f)&1 == 0 } + +func TestIsFinite(t *testing.T) { + finites := []float64{ + 1.0 / 3, + 4891559871276714924261e+222, + math.MaxFloat64, + math.SmallestNonzeroFloat64, + -math.MaxFloat64, + -math.SmallestNonzeroFloat64, + } + for _, f := range finites { + if !isFinite(f) { + t.Errorf("!IsFinite(%g (%b))", f, f) + } + } + nonfinites := []float64{ + math.NaN(), + math.Inf(-1), + math.Inf(+1), + } + for _, f := range nonfinites { + if isFinite(f) { + t.Errorf("IsFinite(%g, (%b))", f, f) + } + } +} |