We aren't using Karatsuba actually but school-grade naive mul
This commit is contained in:
parent
d60419a731
commit
994be7fa61
|
@ -36,15 +36,12 @@ proc `-`*[T: MpUint](a, b: T): T {.noSideEffect, noInit, inline.}=
|
||||||
result = a
|
result = a
|
||||||
result -= b
|
result -= b
|
||||||
|
|
||||||
proc karatsuba[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}
|
proc naiveMul[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}
|
||||||
# Forward declaration
|
# Forward declaration
|
||||||
|
|
||||||
proc `*`*[T: MpUint](a, b: T): T {.noSideEffect, noInit.}=
|
proc `*`*[T: MpUint](a, b: T): T {.noSideEffect, noInit.}=
|
||||||
## Multiplication for multi-precision unsigned uint
|
## Multiplication for multi-precision unsigned uint
|
||||||
#
|
#
|
||||||
# We use a modified Karatsuba algorithm
|
|
||||||
#
|
|
||||||
# Karatsuba algorithm splits the operand into `hi * B + lo`
|
|
||||||
# For our representation, it is similar to school grade multiplication
|
# For our representation, it is similar to school grade multiplication
|
||||||
# Consider hi and lo as if they were digits
|
# Consider hi and lo as if they were digits
|
||||||
#
|
#
|
||||||
|
@ -60,28 +57,40 @@ proc `*`*[T: MpUint](a, b: T): T {.noSideEffect, noInit.}=
|
||||||
#
|
#
|
||||||
# If T is a type
|
# If T is a type
|
||||||
# For T * T --> T we don't need to compute z2 as it always overflow
|
# For T * T --> T we don't need to compute z2 as it always overflow
|
||||||
# For T * T --> 2T (uint64 * uint64 --> uint128) we use the full precision Karatsuba algorithm
|
# For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
|
||||||
|
|
||||||
result = karatsuba(a.lo, b.lo)
|
result = naiveMul(a.lo, b.lo)
|
||||||
result.hi += (karatsuba(a.hi, b.lo) + karatsuba(a.lo, b.hi)).lo
|
result.hi += (naiveMul(a.hi, b.lo) + naiveMul(a.lo, b.hi)).lo
|
||||||
|
|
||||||
template karatsubaImpl[T: MpUint](x, y: T): MpUint[T] =
|
template naiveMulImpl[T: MpUint](x, y: T): MpUint[T] =
|
||||||
|
# See details at
|
||||||
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
|
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
|
||||||
|
# https://locklessinc.com/articles/256bit_arithmetic/
|
||||||
|
# https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
|
||||||
|
#
|
||||||
|
# We use the naive school grade multiplication instead of Karatsuba I.e.
|
||||||
|
# z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
|
||||||
|
#
|
||||||
|
# On modern architecture:
|
||||||
|
# - addition and multiplication have the same cost
|
||||||
|
# - Karatsuba would require to deal with potentially negative intermediate result
|
||||||
|
# and introduce branching
|
||||||
|
# - More total operations means more register moves
|
||||||
|
|
||||||
const halfShl = T.sizeof div 2
|
const halfShl = T.sizeof div 2
|
||||||
let
|
let
|
||||||
z0 = karatsuba(x.lo, y.lo)
|
z0 = naiveMul(x.lo, y.lo)
|
||||||
tmp = karatsuba(x.hi, y.lo)
|
tmp = naiveMul(x.hi, y.lo)
|
||||||
|
|
||||||
var z1 = tmp
|
var z1 = tmp
|
||||||
z1 += karatsuba(x.hi, y.lo)
|
z1 += naiveMul(x.hi, y.lo)
|
||||||
let z2 = (z1 < tmp).T + karatsuba(x.hi, y.hi)
|
let z2 = (z1 < tmp).T + naiveMul(x.hi, y.hi)
|
||||||
|
|
||||||
result.lo = z1.lo shl halfShl + z0
|
result.lo = z1.lo shl halfShl + z0
|
||||||
result.hi = z2 + z1.hi
|
result.hi = z2 + z1.hi
|
||||||
|
|
||||||
proc karatsuba[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}=
|
proc naiveMul[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}=
|
||||||
## Karatsuba algorithm with full precision
|
## Naive multiplication algorithm with extended precision
|
||||||
|
|
||||||
when T.sizeof in {1, 2, 4}:
|
when T.sizeof in {1, 2, 4}:
|
||||||
# Use types twice bigger to do the multiplication
|
# Use types twice bigger to do the multiplication
|
||||||
|
@ -89,8 +98,7 @@ proc karatsuba[T: BaseUint](a, b: T): MpUint[T] {.noSideEffect, noInit, inline.}
|
||||||
|
|
||||||
elif T.sizeof == 8: # uint64 or MpUint[uint32]
|
elif T.sizeof == 8: # uint64 or MpUint[uint32]
|
||||||
# We cannot double uint64 to uint128
|
# We cannot double uint64 to uint128
|
||||||
# We use the Karatsuba algorithm
|
naiveMulImpl(cast[MpUint[uint32]](a), cast[MpUint[uint32]](b))
|
||||||
karatsubaImpl(cast[MpUint[uint32]](a), cast[MpUint[uint32]](b))
|
|
||||||
else:
|
else:
|
||||||
# Case: at least uint128 * uint128 --> uint256
|
# Case: at least uint128 * uint128 --> uint256
|
||||||
karatsubaImpl(a, b)
|
naiveMulImpl(a, b)
|
Loading…
Reference in New Issue