Implement multiplication

This commit is contained in:
Mamy André-Ratsimbazafy 2020-06-12 20:05:40 +02:00 committed by jangko
parent 206ffa92cf
commit a0dec54c12
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
3 changed files with 92 additions and 144 deletions

View File

@ -105,3 +105,33 @@ iterator leastToMostSig*(cLimbs: var Limbs, aLimbs: Limbs, bLimbs: Limbs): (var
else:
for i in countdown(aLimbs.len-1, 0):
yield (cLimbs[i], aLimbs[i], bLimbs[i])
import std/macros
proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode =
# Replace "what" ident node by "by"
proc inspect(node: NimNode): NimNode =
case node.kind:
of {nnkIdent, nnkSym}:
if node.eqIdent(what):
return by
return node
of nnkEmpty:
return node
of nnkLiterals:
return node
else:
var rTree = node.kind.newTree()
for child in node:
rTree.add inspect(child)
return rTree
result = inspect(ast)
macro staticFor*(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped): untyped =
## staticFor [min inclusive, max exclusive)
result = newStmtList()
for i in start ..< stopEx:
result.add nnkBlockStmt.newTree(
ident("unrolledIter_" & $idx & $i),
body.replaceNodes(idx, newLit i)
)

View File

@ -8,7 +8,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
./datatypes, ./uint_comparison, ./uint_bitwise_ops,
./datatypes,
./primitives/addcarry_subborrow
# ############ Addition & Substraction ############ #

View File

@ -7,160 +7,78 @@
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import macros,
./conversion,
./initialization,
./datatypes,
./uint_comparison,
./uint_addsub
import
./datatypes,
./primitives/extended_precision
# ################### Multiplication ################### #
{.push raises: [], gcsafe.}
func lo(x: uint64): uint64 {.inline.} =
const
p: uint64 = 32
base: uint64 = 1'u64 shl p
mask: uint64 = base - 1
result = x and mask
func prod*[rLen, aLen, bLen](r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
## Multi-precision multiplication
## r <- a*b
##
## `a`, `b`, `r` can have a different number of limbs
## if `r`.limbs.len < a.limbs.len + b.limbs.len
## The result will be truncated, i.e. it will be
## a * b (mod (2^WordBitwidth)^r.limbs.len)
func hi(x: uint64): uint64 {.inline.} =
const
p = 32
result = x shr p
# We use Product Scanning / Comba multiplication
var t, u, v = Word(0)
var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems
# No generic, somehow Nim is given ambiguous call with the T: UintImpl overload
func extPrecMul*(result: var UintImpl[uint8], x, y: uint8) {.inline.}=
## Extended precision multiplication
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
staticFor i, 0, min(a.len+b.len, r.len):
const ib = min(b.len-1, i)
const ia = i - ib
staticFor j, 0, min(a.len - ia, ib+1):
mulAcc(t, u, v, a[ia+j], b[ib-j])
func extPrecMul*(result: var UintImpl[uint16], x, y: uint16) {.inline.}=
## Extended precision multiplication
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
z[i] = v
v = u
u = t
t = Word(0)
func extPrecMul*(result: var UintImpl[uint32], x, y: uint32) {.inline.}=
## Extended precision multiplication
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
r = z
func extPrecAddMul[T: uint8 or uint16 or uint32](result: var UintImpl[T], x, y: T) {.inline.}=
## Extended precision fused in-place addition & multiplication
result += cast[type result](x.asDoubleUint * y.asDoubleUint)
template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) =
const
p = 64 div 2
base: uint64 = 1'u64 shl p
var
x0, x1, x2, x3: uint64
let
ul = lo(u)
uh = hi(u)
vl = lo(v)
vh = hi(v)
x0 = ul * vl
x1 = ul * vh
x2 = uh * vl
x3 = uh * vh
x1 += hi(x0) # This can't carry
x1 += x2 # but this can
if x1 < x2: # if carry, add it to x3
x3 += base
op(result.hi, x3 + hi(x1))
op(result.lo, (x1 shl p) or lo(x0))
func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) =
## Extended precision multiplication
extPrecMulImpl(result, `=`, u, v)
func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) =
## Extended precision fused in-place addition & multiplication
extPrecMulImpl(result, `+=`, u, v)
macro eqSym(x, y: untyped): untyped =
let eq = $x == $y # Unfortunately eqIdent compares to string.
result = newLit eq
func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T])
func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T])
# Forward declaration
template extPrecMulImpl*[T](result: var UintImpl[UintImpl[T]], op: untyped, x, y: UintImpl[T]) =
# See details at
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
# https://locklessinc.com/articles/256bit_arithmetic/
# https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
func prod_high_words*[rLen, aLen, bLen](
r: var Limbs[rLen],
a: Limbs[aLen], b: Limbs[bLen],
lowestWordIndex: static int) =
## Multi-precision multiplication keeping only high words
## r <- a*b >> (2^WordBitWidth)^lowestWordIndex
##
## `a`, `b`, `r` can have a different number of limbs
## if `r`.limbs.len < a.limbs.len + b.limbs.len - lowestWordIndex
## The result will be truncated, i.e. it will be
## a * b >> (2^WordBitWidth)^lowestWordIndex (mod (2^WordBitwidth)^r.limbs.len)
#
# We use the naive school grade multiplication instead of Karatsuba I.e.
# z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
#
# On modern architecture:
# - addition and multiplication have the same cost
# - Karatsuba would require to deal with potentially negative intermediate result
# and introduce branching
# - More total operations means more register moves
# This is useful for
# - Barret reduction
# - Approximating multiplication by a fractional constant in the form f(a) = K/C * a
# with K and C known at compile-time.
# We can instead find a well chosen M = (2^WordBitWidth)^w, with M > C (i.e. M is a power of 2 bigger than C)
# Precompute P = K*M/C at compile-time
# and at runtime do P*a/M <=> P*a >> (WordBitWidth*w)
# i.e. prod_high_words(result, P, a, w)
var z1: type x
# We use Product Scanning / Comba multiplication
var t, u, v = Word(0) # Will raise warning on empty iterations
var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems
# Low part and hi part - z0 & z2
when eqSym(op, `+=`):
extPrecAddMul(result.lo, x.lo, y.lo)
extPrecAddMul(result.hi, x.hi, y.hi)
else:
extPrecMul(result.lo, x.lo, y.lo)
extPrecMul(result.hi, x.hi, y.hi)
# The previous 2 columns can affect the lowest word due to carries
# but not the ones before (we accumulate in 3 words (t, u, v))
const w = lowestWordIndex - 2
## TODO - fuse those parts and reduce the number of carry checks
# Middle part - z1 - 1st mul
extPrecMul(z1, x.hi, y.lo)
result.lo.hi += z1.lo
if result.lo.hi < z1.lo:
inc result.hi
staticFor i, max(0, w), min(a.len+b.len, r.len+lowestWordIndex):
const ib = min(b.len-1, i)
const ia = i - ib
staticFor j, 0, min(a.len - ia, ib+1):
mulAcc(t, u, v, a[ia+j], b[ib-j])
result.hi.lo += z1.hi
if result.hi.lo < z1.hi:
inc result.hi.hi
when i >= lowestWordIndex:
z[i-lowestWordIndex] = v
v = u
u = t
t = Word(0)
# Middle part - z1 - 2nd mul
extPrecMul(z1, x.lo, y.hi)
result.lo.hi += z1.lo
if result.lo.hi < z1.lo:
inc result.hi
result.hi.lo += z1.hi
if result.hi.lo < z1.hi:
inc result.hi.hi
func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) =
## Extended precision fused in-place addition & multiplication
extPrecMulImpl(result, `+=`, u, v)
func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) =
## Extended precision multiplication
extPrecMulImpl(result, `=`, u, v)
func `*`*[T](x, y: UintImpl[T]): UintImpl[T] {.inline.}=
## Multiplication for multi-precision unsigned uint
#
# For our representation, it is similar to school grade multiplication
# Consider hi and lo as if they were digits
#
# 12
# X 15
# ------
# 10 lo*lo -> z0
# 5 hi*lo -> z1
# 2 lo*hi -> z1
# 10 hi*hi -- z2
# ------
# 180
#
# If T is a type
# For T * T --> T we don't need to compute z2 as it always overflow
# For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
extPrecMul(result, x.lo, y.lo)
result.hi += x.lo * y.hi + x.hi * y.lo
r = z