diff --git a/src/bigints.nim b/src/bigints.nim index 2b0ab82..0b5e701 100644 --- a/src/bigints.nim +++ b/src/bigints.nim @@ -64,7 +64,7 @@ func initBigInt*(val: BigInt): BigInt = const zero = initBigInt(0) one = initBigInt(1) - karatsubaTreshold = 5 + karatsubaTreshold = 10 func isZero(a: BigInt): bool {.inline.} = for i in countdown(a.limbs.high, 0): @@ -418,35 +418,25 @@ func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = inc pos normalize(a) -func scalarMultiplication(a: var BigInt, b: uint32, c: BigInt) {.inline.} = - # Based on unsignedMultiplication +func scalarMultiplication(a: var BigInt, b: BigInt, c: uint32) {.inline.} = + # always called with bl >= cl let - cl = c.limbs.len - a.limbs.setLen(1 + cl) + bl = b.limbs.len + a.limbs.setLen(bl + 1) var tmp = 0'u64 - tmp += uint64(b) * uint64(c.limbs[0]) - a.limbs[1] = uint32(tmp and uint32.high) - tmp = tmp shr 32 # carry - - a.limbs[1] = uint32(tmp) - - for j in 1 ..< cl: - tmp = 0'u64 - tmp += uint64(a.limbs[j]) + uint64(b) * uint64(c.limbs[j]) - a.limbs[j] = uint32(tmp and uint32.high) + for i in 0 ..< bl: + tmp += uint64(b.limbs[i]) * uint64(c) + a.limbs[i] = uint32(tmp and uint32.high) tmp = tmp shr 32 - var pos = j + 1 - while tmp > 0'u64: - tmp += uint64(a.limbs[pos]) - a.limbs[pos] = uint32(tmp and uint32.high) - tmp = tmp shr 32 - inc pos + + a.limbs[bl] = uint32(tmp) normalize(a) # forward declaration for use in `multiplication` -func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} +func karatsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} func `shl`*(x: BigInt, y: Natural): BigInt +func `shr`*(x: BigInt, y: Natural): BigInt func multiplication(a: var BigInt, b, c: BigInt) = # a = b * c @@ -459,28 +449,27 @@ func multiplication(a: var BigInt, b, c: BigInt) = if cl > bl: if bl <= karatsubaTreshold: - unsignedKaratsubaMultiplication(a, c, b) + karatsubaMultiplication(a, c, b) else: unsignedMultiplication(a, c, b) else: if cl <= karatsubaTreshold: - unsignedKaratsubaMultiplication(a, b, c) + karatsubaMultiplication(a, b, c) else: unsignedMultiplication(a, b, c) a.isNegative = b.isNegative xor c.isNegative -func `shr`*(x: BigInt, y: Natural): BigInt -func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = +func karatsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = let bl = b.limbs.len cl = c.limbs.len let n = max(bl, cl) if bl == 1: # base case : multiply the only limb with each limb of second term - scalarMultiplication(a, b.limbs[0], c) + scalarMultiplication(a, c, b.limbs[0]) return if cl == 1: - scalarMultiplication(a, c.limbs[0], b) + scalarMultiplication(a, b, c.limbs[0]) return if bl < karatsubaTreshold: if cl <= bl: @@ -507,21 +496,19 @@ func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} = # limit carry handling in opposition to the additive version var lowProduct, highProduct, A3, A4, A5, middleTerm: BigInt = zero - unsignedKaratsubaMultiplication(lowProduct, low_b, low_c) - unsignedKaratsubaMultiplication(highProduct, high_b, high_c) + karatsubaMultiplication(lowProduct, low_b, low_c) + karatsubaMultiplication(highProduct, high_b, high_c) A3 = low_b - high_b # Additive variant of Karatsuba - A4 = high_c - low_c # would add them + A4 = low_c - high_c # would add them if A4.limbs.len >= A3.limbs.len: multiplication(A5, abs(A4), abs(A3)) else: multiplication(A5, abs(A3), abs(A4)) middleTerm = lowProduct + highProduct + A5 - a = lowProduct + (middleTerm shr k) + (highProduct shr (2*k)) - # We could affect directly some of the bits of the result with slicing - # a.limbs[0 .. k - 1] = lowProduct.limbs - # But the following instructions would not be correct due to sign handling - # a.limbs[k .. 2*k-1] = middleTerm.limbs - # a.limbs[2*k .. 3*k-1] = highProduct.limbs + a.limbs[0 .. k - 1] = lowProduct.limbs + # a += (middleTerm shr k) + (highProduct shr (2*k)) + a.limbs[k .. 2*k-1] = middleTerm.limbs + a.limbs[2*k .. 3*k-1] = highProduct.limbs func `*`*(a, b: BigInt): BigInt =