From a0dec54c12926bac09d615400e511a9884febb13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Fri, 12 Jun 2020 20:05:40 +0200 Subject: [PATCH] Implement multiplication --- stint/private/datatypes.nim | 30 +++++ stint/private/uint_addsub.nim | 2 +- stint/private/uint_mul.nim | 204 ++++++++++------------------------ 3 files changed, 92 insertions(+), 144 deletions(-) diff --git a/stint/private/datatypes.nim b/stint/private/datatypes.nim index 987af07..4fa1cfb 100644 --- a/stint/private/datatypes.nim +++ b/stint/private/datatypes.nim @@ -105,3 +105,33 @@ iterator leastToMostSig*(cLimbs: var Limbs, aLimbs: Limbs, bLimbs: Limbs): (var else: for i in countdown(aLimbs.len-1, 0): yield (cLimbs[i], aLimbs[i], bLimbs[i]) + +import std/macros + +proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode = + # Replace "what" ident node by "by" + proc inspect(node: NimNode): NimNode = + case node.kind: + of {nnkIdent, nnkSym}: + if node.eqIdent(what): + return by + return node + of nnkEmpty: + return node + of nnkLiterals: + return node + else: + var rTree = node.kind.newTree() + for child in node: + rTree.add inspect(child) + return rTree + result = inspect(ast) + +macro staticFor*(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped): untyped = + ## staticFor [min inclusive, max exclusive) + result = newStmtList() + for i in start ..< stopEx: + result.add nnkBlockStmt.newTree( + ident("unrolledIter_" & $idx & $i), + body.replaceNodes(idx, newLit i) + ) diff --git a/stint/private/uint_addsub.nim b/stint/private/uint_addsub.nim index 135a1b2..c99795e 100644 --- a/stint/private/uint_addsub.nim +++ b/stint/private/uint_addsub.nim @@ -8,7 +8,7 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import - ./datatypes, ./uint_comparison, ./uint_bitwise_ops, + ./datatypes, ./primitives/addcarry_subborrow # ############ Addition & Substraction ############ # diff --git a/stint/private/uint_mul.nim b/stint/private/uint_mul.nim index 055289c..022cef5 100644 --- a/stint/private/uint_mul.nim +++ b/stint/private/uint_mul.nim @@ -7,160 +7,78 @@ # # at your option. This file may not be copied, modified, or distributed except according to those terms. -import macros, - ./conversion, - ./initialization, - ./datatypes, - ./uint_comparison, - ./uint_addsub +import + ./datatypes, + ./primitives/extended_precision # ################### Multiplication ################### # +{.push raises: [], gcsafe.} -func lo(x: uint64): uint64 {.inline.} = - const - p: uint64 = 32 - base: uint64 = 1'u64 shl p - mask: uint64 = base - 1 - result = x and mask +func prod*[rLen, aLen, bLen](r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) = + ## Multi-precision multiplication + ## r <- a*b + ## + ## `a`, `b`, `r` can have a different number of limbs + ## if `r`.limbs.len < a.limbs.len + b.limbs.len + ## The result will be truncated, i.e. it will be + ## a * b (mod (2^WordBitwidth)^r.limbs.len) -func hi(x: uint64): uint64 {.inline.} = - const - p = 32 - result = x shr p + # We use Product Scanning / Comba multiplication + var t, u, v = Word(0) + var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems -# No generic, somehow Nim is given ambiguous call with the T: UintImpl overload -func extPrecMul*(result: var UintImpl[uint8], x, y: uint8) {.inline.}= - ## Extended precision multiplication - result = cast[type result](x.asDoubleUint * y.asDoubleUint) + staticFor i, 0, min(a.len+b.len, r.len): + const ib = min(b.len-1, i) + const ia = i - ib + staticFor j, 0, min(a.len - ia, ib+1): + mulAcc(t, u, v, a[ia+j], b[ib-j]) -func extPrecMul*(result: var UintImpl[uint16], x, y: uint16) {.inline.}= - ## Extended precision multiplication - result = cast[type result](x.asDoubleUint * y.asDoubleUint) + z[i] = v + v = u + u = t + t = Word(0) -func extPrecMul*(result: var UintImpl[uint32], x, y: uint32) {.inline.}= - ## Extended precision multiplication - result = cast[type result](x.asDoubleUint * y.asDoubleUint) + r = z -func extPrecAddMul[T: uint8 or uint16 or uint32](result: var UintImpl[T], x, y: T) {.inline.}= - ## Extended precision fused in-place addition & multiplication - result += cast[type result](x.asDoubleUint * y.asDoubleUint) - -template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) = - const - p = 64 div 2 - base: uint64 = 1'u64 shl p - - var - x0, x1, x2, x3: uint64 - - let - ul = lo(u) - uh = hi(u) - vl = lo(v) - vh = hi(v) - - x0 = ul * vl - x1 = ul * vh - x2 = uh * vl - x3 = uh * vh - - x1 += hi(x0) # This can't carry - x1 += x2 # but this can - if x1 < x2: # if carry, add it to x3 - x3 += base - - op(result.hi, x3 + hi(x1)) - op(result.lo, (x1 shl p) or lo(x0)) - -func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) = - ## Extended precision multiplication - extPrecMulImpl(result, `=`, u, v) - -func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) = - ## Extended precision fused in-place addition & multiplication - extPrecMulImpl(result, `+=`, u, v) - -macro eqSym(x, y: untyped): untyped = - let eq = $x == $y # Unfortunately eqIdent compares to string. - result = newLit eq - -func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) -func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) - # Forward declaration - -template extPrecMulImpl*[T](result: var UintImpl[UintImpl[T]], op: untyped, x, y: UintImpl[T]) = - # See details at - # https://en.wikipedia.org/wiki/Karatsuba_algorithm - # https://locklessinc.com/articles/256bit_arithmetic/ - # https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott +func prod_high_words*[rLen, aLen, bLen]( + r: var Limbs[rLen], + a: Limbs[aLen], b: Limbs[bLen], + lowestWordIndex: static int) = + ## Multi-precision multiplication keeping only high words + ## r <- a*b >> (2^WordBitWidth)^lowestWordIndex + ## + ## `a`, `b`, `r` can have a different number of limbs + ## if `r`.limbs.len < a.limbs.len + b.limbs.len - lowestWordIndex + ## The result will be truncated, i.e. it will be + ## a * b >> (2^WordBitWidth)^lowestWordIndex (mod (2^WordBitwidth)^r.limbs.len) # - # We use the naive school grade multiplication instead of Karatsuba I.e. - # z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba) - # - # On modern architecture: - # - addition and multiplication have the same cost - # - Karatsuba would require to deal with potentially negative intermediate result - # and introduce branching - # - More total operations means more register moves + # This is useful for + # - Barret reduction + # - Approximating multiplication by a fractional constant in the form f(a) = K/C * a + # with K and C known at compile-time. + # We can instead find a well chosen M = (2^WordBitWidth)^w, with M > C (i.e. M is a power of 2 bigger than C) + # Precompute P = K*M/C at compile-time + # and at runtime do P*a/M <=> P*a >> (WordBitWidth*w) + # i.e. prod_high_words(result, P, a, w) - var z1: type x + # We use Product Scanning / Comba multiplication + var t, u, v = Word(0) # Will raise warning on empty iterations + var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems - # Low part and hi part - z0 & z2 - when eqSym(op, `+=`): - extPrecAddMul(result.lo, x.lo, y.lo) - extPrecAddMul(result.hi, x.hi, y.hi) - else: - extPrecMul(result.lo, x.lo, y.lo) - extPrecMul(result.hi, x.hi, y.hi) + # The previous 2 columns can affect the lowest word due to carries + # but not the ones before (we accumulate in 3 words (t, u, v)) + const w = lowestWordIndex - 2 - ## TODO - fuse those parts and reduce the number of carry checks - # Middle part - z1 - 1st mul - extPrecMul(z1, x.hi, y.lo) - result.lo.hi += z1.lo - if result.lo.hi < z1.lo: - inc result.hi + staticFor i, max(0, w), min(a.len+b.len, r.len+lowestWordIndex): + const ib = min(b.len-1, i) + const ia = i - ib + staticFor j, 0, min(a.len - ia, ib+1): + mulAcc(t, u, v, a[ia+j], b[ib-j]) - result.hi.lo += z1.hi - if result.hi.lo < z1.hi: - inc result.hi.hi + when i >= lowestWordIndex: + z[i-lowestWordIndex] = v + v = u + u = t + t = Word(0) - # Middle part - z1 - 2nd mul - extPrecMul(z1, x.lo, y.hi) - result.lo.hi += z1.lo - if result.lo.hi < z1.lo: - inc result.hi - - result.hi.lo += z1.hi - if result.hi.lo < z1.hi: - inc result.hi.hi - -func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) = - ## Extended precision fused in-place addition & multiplication - extPrecMulImpl(result, `+=`, u, v) - -func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) = - ## Extended precision multiplication - extPrecMulImpl(result, `=`, u, v) - -func `*`*[T](x, y: UintImpl[T]): UintImpl[T] {.inline.}= - ## Multiplication for multi-precision unsigned uint - # - # For our representation, it is similar to school grade multiplication - # Consider hi and lo as if they were digits - # - # 12 - # X 15 - # ------ - # 10 lo*lo -> z0 - # 5 hi*lo -> z1 - # 2 lo*hi -> z1 - # 10 hi*hi -- z2 - # ------ - # 180 - # - # If T is a type - # For T * T --> T we don't need to compute z2 as it always overflow - # For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication - - extPrecMul(result, x.lo, y.lo) - result.hi += x.lo * y.hi + x.hi * y.lo + r = z