Implement multiplication
This commit is contained in:
parent
206ffa92cf
commit
a0dec54c12
|
@ -105,3 +105,33 @@ iterator leastToMostSig*(cLimbs: var Limbs, aLimbs: Limbs, bLimbs: Limbs): (var
|
|||
else:
|
||||
for i in countdown(aLimbs.len-1, 0):
|
||||
yield (cLimbs[i], aLimbs[i], bLimbs[i])
|
||||
|
||||
import std/macros
|
||||
|
||||
proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode =
|
||||
# Replace "what" ident node by "by"
|
||||
proc inspect(node: NimNode): NimNode =
|
||||
case node.kind:
|
||||
of {nnkIdent, nnkSym}:
|
||||
if node.eqIdent(what):
|
||||
return by
|
||||
return node
|
||||
of nnkEmpty:
|
||||
return node
|
||||
of nnkLiterals:
|
||||
return node
|
||||
else:
|
||||
var rTree = node.kind.newTree()
|
||||
for child in node:
|
||||
rTree.add inspect(child)
|
||||
return rTree
|
||||
result = inspect(ast)
|
||||
|
||||
macro staticFor*(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped): untyped =
|
||||
## staticFor [min inclusive, max exclusive)
|
||||
result = newStmtList()
|
||||
for i in start ..< stopEx:
|
||||
result.add nnkBlockStmt.newTree(
|
||||
ident("unrolledIter_" & $idx & $i),
|
||||
body.replaceNodes(idx, newLit i)
|
||||
)
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import
|
||||
./datatypes, ./uint_comparison, ./uint_bitwise_ops,
|
||||
./datatypes,
|
||||
./primitives/addcarry_subborrow
|
||||
|
||||
# ############ Addition & Substraction ############ #
|
||||
|
|
|
@ -7,160 +7,78 @@
|
|||
#
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import macros,
|
||||
./conversion,
|
||||
./initialization,
|
||||
./datatypes,
|
||||
./uint_comparison,
|
||||
./uint_addsub
|
||||
import
|
||||
./datatypes,
|
||||
./primitives/extended_precision
|
||||
|
||||
# ################### Multiplication ################### #
|
||||
{.push raises: [], gcsafe.}
|
||||
|
||||
func lo(x: uint64): uint64 {.inline.} =
|
||||
const
|
||||
p: uint64 = 32
|
||||
base: uint64 = 1'u64 shl p
|
||||
mask: uint64 = base - 1
|
||||
result = x and mask
|
||||
func prod*[rLen, aLen, bLen](r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
|
||||
## Multi-precision multiplication
|
||||
## r <- a*b
|
||||
##
|
||||
## `a`, `b`, `r` can have a different number of limbs
|
||||
## if `r`.limbs.len < a.limbs.len + b.limbs.len
|
||||
## The result will be truncated, i.e. it will be
|
||||
## a * b (mod (2^WordBitwidth)^r.limbs.len)
|
||||
|
||||
func hi(x: uint64): uint64 {.inline.} =
|
||||
const
|
||||
p = 32
|
||||
result = x shr p
|
||||
# We use Product Scanning / Comba multiplication
|
||||
var t, u, v = Word(0)
|
||||
var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems
|
||||
|
||||
# No generic, somehow Nim is given ambiguous call with the T: UintImpl overload
|
||||
func extPrecMul*(result: var UintImpl[uint8], x, y: uint8) {.inline.}=
|
||||
## Extended precision multiplication
|
||||
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
|
||||
staticFor i, 0, min(a.len+b.len, r.len):
|
||||
const ib = min(b.len-1, i)
|
||||
const ia = i - ib
|
||||
staticFor j, 0, min(a.len - ia, ib+1):
|
||||
mulAcc(t, u, v, a[ia+j], b[ib-j])
|
||||
|
||||
func extPrecMul*(result: var UintImpl[uint16], x, y: uint16) {.inline.}=
|
||||
## Extended precision multiplication
|
||||
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
|
||||
z[i] = v
|
||||
v = u
|
||||
u = t
|
||||
t = Word(0)
|
||||
|
||||
func extPrecMul*(result: var UintImpl[uint32], x, y: uint32) {.inline.}=
|
||||
## Extended precision multiplication
|
||||
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
|
||||
r = z
|
||||
|
||||
func extPrecAddMul[T: uint8 or uint16 or uint32](result: var UintImpl[T], x, y: T) {.inline.}=
|
||||
## Extended precision fused in-place addition & multiplication
|
||||
result += cast[type result](x.asDoubleUint * y.asDoubleUint)
|
||||
|
||||
template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) =
|
||||
const
|
||||
p = 64 div 2
|
||||
base: uint64 = 1'u64 shl p
|
||||
|
||||
var
|
||||
x0, x1, x2, x3: uint64
|
||||
|
||||
let
|
||||
ul = lo(u)
|
||||
uh = hi(u)
|
||||
vl = lo(v)
|
||||
vh = hi(v)
|
||||
|
||||
x0 = ul * vl
|
||||
x1 = ul * vh
|
||||
x2 = uh * vl
|
||||
x3 = uh * vh
|
||||
|
||||
x1 += hi(x0) # This can't carry
|
||||
x1 += x2 # but this can
|
||||
if x1 < x2: # if carry, add it to x3
|
||||
x3 += base
|
||||
|
||||
op(result.hi, x3 + hi(x1))
|
||||
op(result.lo, (x1 shl p) or lo(x0))
|
||||
|
||||
func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) =
|
||||
## Extended precision multiplication
|
||||
extPrecMulImpl(result, `=`, u, v)
|
||||
|
||||
func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) =
|
||||
## Extended precision fused in-place addition & multiplication
|
||||
extPrecMulImpl(result, `+=`, u, v)
|
||||
|
||||
macro eqSym(x, y: untyped): untyped =
|
||||
let eq = $x == $y # Unfortunately eqIdent compares to string.
|
||||
result = newLit eq
|
||||
|
||||
func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T])
|
||||
func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T])
|
||||
# Forward declaration
|
||||
|
||||
template extPrecMulImpl*[T](result: var UintImpl[UintImpl[T]], op: untyped, x, y: UintImpl[T]) =
|
||||
# See details at
|
||||
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
|
||||
# https://locklessinc.com/articles/256bit_arithmetic/
|
||||
# https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
|
||||
func prod_high_words*[rLen, aLen, bLen](
|
||||
r: var Limbs[rLen],
|
||||
a: Limbs[aLen], b: Limbs[bLen],
|
||||
lowestWordIndex: static int) =
|
||||
## Multi-precision multiplication keeping only high words
|
||||
## r <- a*b >> (2^WordBitWidth)^lowestWordIndex
|
||||
##
|
||||
## `a`, `b`, `r` can have a different number of limbs
|
||||
## if `r`.limbs.len < a.limbs.len + b.limbs.len - lowestWordIndex
|
||||
## The result will be truncated, i.e. it will be
|
||||
## a * b >> (2^WordBitWidth)^lowestWordIndex (mod (2^WordBitwidth)^r.limbs.len)
|
||||
#
|
||||
# We use the naive school grade multiplication instead of Karatsuba I.e.
|
||||
# z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
|
||||
#
|
||||
# On modern architecture:
|
||||
# - addition and multiplication have the same cost
|
||||
# - Karatsuba would require to deal with potentially negative intermediate result
|
||||
# and introduce branching
|
||||
# - More total operations means more register moves
|
||||
# This is useful for
|
||||
# - Barret reduction
|
||||
# - Approximating multiplication by a fractional constant in the form f(a) = K/C * a
|
||||
# with K and C known at compile-time.
|
||||
# We can instead find a well chosen M = (2^WordBitWidth)^w, with M > C (i.e. M is a power of 2 bigger than C)
|
||||
# Precompute P = K*M/C at compile-time
|
||||
# and at runtime do P*a/M <=> P*a >> (WordBitWidth*w)
|
||||
# i.e. prod_high_words(result, P, a, w)
|
||||
|
||||
var z1: type x
|
||||
# We use Product Scanning / Comba multiplication
|
||||
var t, u, v = Word(0) # Will raise warning on empty iterations
|
||||
var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems
|
||||
|
||||
# Low part and hi part - z0 & z2
|
||||
when eqSym(op, `+=`):
|
||||
extPrecAddMul(result.lo, x.lo, y.lo)
|
||||
extPrecAddMul(result.hi, x.hi, y.hi)
|
||||
else:
|
||||
extPrecMul(result.lo, x.lo, y.lo)
|
||||
extPrecMul(result.hi, x.hi, y.hi)
|
||||
# The previous 2 columns can affect the lowest word due to carries
|
||||
# but not the ones before (we accumulate in 3 words (t, u, v))
|
||||
const w = lowestWordIndex - 2
|
||||
|
||||
## TODO - fuse those parts and reduce the number of carry checks
|
||||
# Middle part - z1 - 1st mul
|
||||
extPrecMul(z1, x.hi, y.lo)
|
||||
result.lo.hi += z1.lo
|
||||
if result.lo.hi < z1.lo:
|
||||
inc result.hi
|
||||
staticFor i, max(0, w), min(a.len+b.len, r.len+lowestWordIndex):
|
||||
const ib = min(b.len-1, i)
|
||||
const ia = i - ib
|
||||
staticFor j, 0, min(a.len - ia, ib+1):
|
||||
mulAcc(t, u, v, a[ia+j], b[ib-j])
|
||||
|
||||
result.hi.lo += z1.hi
|
||||
if result.hi.lo < z1.hi:
|
||||
inc result.hi.hi
|
||||
when i >= lowestWordIndex:
|
||||
z[i-lowestWordIndex] = v
|
||||
v = u
|
||||
u = t
|
||||
t = Word(0)
|
||||
|
||||
# Middle part - z1 - 2nd mul
|
||||
extPrecMul(z1, x.lo, y.hi)
|
||||
result.lo.hi += z1.lo
|
||||
if result.lo.hi < z1.lo:
|
||||
inc result.hi
|
||||
|
||||
result.hi.lo += z1.hi
|
||||
if result.hi.lo < z1.hi:
|
||||
inc result.hi.hi
|
||||
|
||||
func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) =
|
||||
## Extended precision fused in-place addition & multiplication
|
||||
extPrecMulImpl(result, `+=`, u, v)
|
||||
|
||||
func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) =
|
||||
## Extended precision multiplication
|
||||
extPrecMulImpl(result, `=`, u, v)
|
||||
|
||||
func `*`*[T](x, y: UintImpl[T]): UintImpl[T] {.inline.}=
|
||||
## Multiplication for multi-precision unsigned uint
|
||||
#
|
||||
# For our representation, it is similar to school grade multiplication
|
||||
# Consider hi and lo as if they were digits
|
||||
#
|
||||
# 12
|
||||
# X 15
|
||||
# ------
|
||||
# 10 lo*lo -> z0
|
||||
# 5 hi*lo -> z1
|
||||
# 2 lo*hi -> z1
|
||||
# 10 hi*hi -- z2
|
||||
# ------
|
||||
# 180
|
||||
#
|
||||
# If T is a type
|
||||
# For T * T --> T we don't need to compute z2 as it always overflow
|
||||
# For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
|
||||
|
||||
extPrecMul(result, x.lo, y.lo)
|
||||
result.hi += x.lo * y.hi + x.hi * y.lo
|
||||
r = z
|
||||
|
|
Loading…
Reference in New Issue