diff --git a/src/uint_binary_ops.nim b/src/uint_binary_ops.nim index 76ab59f..f56d1e8 100644 --- a/src/uint_binary_ops.nim +++ b/src/uint_binary_ops.nim @@ -36,15 +36,12 @@ proc `-`*[T: MpUint](a, b: T): T {.noSideEffect, noInit, inline.}= result = a result -= b -proc karatsuba[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.} +proc naiveMul[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.} # Forward declaration proc `*`*[T: MpUint](a, b: T): T {.noSideEffect, noInit.}= ## Multiplication for multi-precision unsigned uint # - # We use a modified Karatsuba algorithm - # - # Karatsuba algorithm splits the operand into `hi * B + lo` # For our representation, it is similar to school grade multiplication # Consider hi and lo as if they were digits # @@ -60,28 +57,40 @@ proc `*`*[T: MpUint](a, b: T): T {.noSideEffect, noInit.}= # # If T is a type # For T * T --> T we don't need to compute z2 as it always overflow - # For T * T --> 2T (uint64 * uint64 --> uint128) we use the full precision Karatsuba algorithm + # For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication - result = karatsuba(a.lo, b.lo) - result.hi += (karatsuba(a.hi, b.lo) + karatsuba(a.lo, b.hi)).lo + result = naiveMul(a.lo, b.lo) + result.hi += (naiveMul(a.hi, b.lo) + naiveMul(a.lo, b.hi)).lo -template karatsubaImpl[T: MpUint](x, y: T): MpUint[T] = +template naiveMulImpl[T: MpUint](x, y: T): MpUint[T] = + # See details at # https://en.wikipedia.org/wiki/Karatsuba_algorithm + # https://locklessinc.com/articles/256bit_arithmetic/ + # https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott + # + # We use the naive school grade multiplication instead of Karatsuba I.e. + # z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba) + # + # On modern architecture: + # - addition and multiplication have the same cost + # - Karatsuba would require to deal with potentially negative intermediate result + # and introduce branching + # - More total operations means more register moves const halfShl = T.sizeof div 2 let - z0 = karatsuba(x.lo, y.lo) - tmp = karatsuba(x.hi, y.lo) + z0 = naiveMul(x.lo, y.lo) + tmp = naiveMul(x.hi, y.lo) var z1 = tmp - z1 += karatsuba(x.hi, y.lo) - let z2 = (z1 < tmp).T + karatsuba(x.hi, y.hi) + z1 += naiveMul(x.hi, y.lo) + let z2 = (z1 < tmp).T + naiveMul(x.hi, y.hi) result.lo = z1.lo shl halfShl + z0 result.hi = z2 + z1.hi -proc karatsuba[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}= - ## Karatsuba algorithm with full precision +proc naiveMul[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}= + ## Naive multiplication algorithm with extended precision when T.sizeof in {1, 2, 4}: # Use types twice bigger to do the multiplication @@ -89,8 +98,7 @@ proc karatsuba[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.} elif T.sizeof == 8: # uint64 or MpUint[uint32] # We cannot double uint64 to uint128 - # We use the Karatsuba algorithm - karatsubaImpl(cast[MpUint[uint32]](a), cast[MpUint[uint32]](b)) + naiveMulImpl(cast[MpUint[uint32]](a), cast[MpUint[uint32]](b)) else: # Case: at least uint128 * uint128 --> uint256 - karatsubaImpl(a, b) \ No newline at end of file + naiveMulImpl(a, b) \ No newline at end of file