mirror of
https://github.com/status-im/nim-stint.git
synced 2025-02-18 01:47:27 +00:00
Implement multiplication
This commit is contained in:
parent
206ffa92cf
commit
a0dec54c12
@ -105,3 +105,33 @@ iterator leastToMostSig*(cLimbs: var Limbs, aLimbs: Limbs, bLimbs: Limbs): (var
|
|||||||
else:
|
else:
|
||||||
for i in countdown(aLimbs.len-1, 0):
|
for i in countdown(aLimbs.len-1, 0):
|
||||||
yield (cLimbs[i], aLimbs[i], bLimbs[i])
|
yield (cLimbs[i], aLimbs[i], bLimbs[i])
|
||||||
|
|
||||||
|
import std/macros
|
||||||
|
|
||||||
|
proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode =
  ## Returns a rebuilt copy of `ast` in which every identifier or symbol
  ## node that `eqIdent`-matches `what` is substituted with `by`.
  ##
  ## Empty and literal nodes are returned as-is; every other node kind is
  ## recreated with `newTree` and its children are processed recursively.
  proc rewrite(n: NimNode): NimNode =
    case n.kind
    of nnkIdent, nnkSym:
      # Only identifier-like leaves are candidates for substitution.
      result = (if n.eqIdent(what): by else: n)
    of nnkEmpty, nnkLiterals:
      # Leaves that can never match: hand them back untouched.
      result = n
    else:
      # Interior node: rebuild it from its rewritten children.
      result = newTree(n.kind)
      for kid in n:
        result.add rewrite(kid)

  rewrite(ast)
|
||||||
|
|
||||||
|
macro staticFor*(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped): untyped =
  ## Compile-time unrolled loop over the half-open range
  ## [`start` inclusive, `stopEx` exclusive).
  ##
  ## For each value in the range, emits one labelled `block` statement
  ## containing `body` with every occurrence of `idx` replaced by the
  ## corresponding integer literal. The blocks are collected, in order,
  ## into a single statement list.
  let labelPrefix = "unrolledIter_" & $idx
  result = newStmtList()
  for step in start ..< stopEx:
    let
      # Block label is built from the index name plus the iteration value.
      label = ident(labelPrefix & $step)
      unrolled = body.replaceNodes(idx, newLit step)
    result.add newTree(nnkBlockStmt, label, unrolled)
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
import
|
import
|
||||||
./datatypes, ./uint_comparison, ./uint_bitwise_ops,
|
./datatypes,
|
||||||
./primitives/addcarry_subborrow
|
./primitives/addcarry_subborrow
|
||||||
|
|
||||||
# ############ Addition & Subtraction ############ #
|
# ############ Addition & Subtraction ############ #
|
||||||
|
@ -7,160 +7,78 @@
|
|||||||
#
|
#
|
||||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
import macros,
|
import
|
||||||
./conversion,
|
./datatypes,
|
||||||
./initialization,
|
./primitives/extended_precision
|
||||||
./datatypes,
|
|
||||||
./uint_comparison,
|
|
||||||
./uint_addsub
|
|
||||||
|
|
||||||
# ################### Multiplication ################### #
|
# ################### Multiplication ################### #
|
||||||
|
{.push raises: [], gcsafe.}
|
||||||
|
|
||||||
func lo(x: uint64): uint64 {.inline.} =
|
func prod*[rLen, aLen, bLen](r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
|
||||||
const
|
## Multi-precision multiplication
|
||||||
p: uint64 = 32
|
## r <- a*b
|
||||||
base: uint64 = 1'u64 shl p
|
##
|
||||||
mask: uint64 = base - 1
|
## `a`, `b`, `r` can have a different number of limbs
|
||||||
result = x and mask
|
## if `r`.limbs.len < a.limbs.len + b.limbs.len
|
||||||
|
## The result will be truncated, i.e. it will be
|
||||||
|
## a * b (mod (2^WordBitwidth)^r.limbs.len)
|
||||||
|
|
||||||
func hi(x: uint64): uint64 {.inline.} =
|
# We use Product Scanning / Comba multiplication
|
||||||
const
|
var t, u, v = Word(0)
|
||||||
p = 32
|
var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems
|
||||||
result = x shr p
|
|
||||||
|
|
||||||
# No generic, somehow Nim is given ambiguous call with the T: UintImpl overload
|
staticFor i, 0, min(a.len+b.len, r.len):
|
||||||
func extPrecMul*(result: var UintImpl[uint8], x, y: uint8) {.inline.}=
|
const ib = min(b.len-1, i)
|
||||||
## Extended precision multiplication
|
const ia = i - ib
|
||||||
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
|
staticFor j, 0, min(a.len - ia, ib+1):
|
||||||
|
mulAcc(t, u, v, a[ia+j], b[ib-j])
|
||||||
|
|
||||||
func extPrecMul*(result: var UintImpl[uint16], x, y: uint16) {.inline.}=
|
z[i] = v
|
||||||
## Extended precision multiplication
|
v = u
|
||||||
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
|
u = t
|
||||||
|
t = Word(0)
|
||||||
|
|
||||||
func extPrecMul*(result: var UintImpl[uint32], x, y: uint32) {.inline.}=
|
r = z
|
||||||
## Extended precision multiplication
|
|
||||||
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
|
|
||||||
|
|
||||||
func extPrecAddMul[T: uint8 or uint16 or uint32](result: var UintImpl[T], x, y: T) {.inline.}=
|
func prod_high_words*[rLen, aLen, bLen](
|
||||||
## Extended precision fused in-place addition & multiplication
|
r: var Limbs[rLen],
|
||||||
result += cast[type result](x.asDoubleUint * y.asDoubleUint)
|
a: Limbs[aLen], b: Limbs[bLen],
|
||||||
|
lowestWordIndex: static int) =
|
||||||
template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) =
|
## Multi-precision multiplication keeping only high words
|
||||||
const
|
## r <- a*b >> (WordBitWidth * lowestWordIndex)
|
||||||
p = 64 div 2
|
##
|
||||||
base: uint64 = 1'u64 shl p
|
## `a`, `b`, `r` can have a different number of limbs
|
||||||
|
## if `r`.limbs.len < a.limbs.len + b.limbs.len - lowestWordIndex
|
||||||
var
|
## The result will be truncated, i.e. it will be
|
||||||
x0, x1, x2, x3: uint64
|
## (a * b >> (WordBitWidth * lowestWordIndex)) (mod (2^WordBitWidth)^r.limbs.len)
|
||||||
|
|
||||||
let
|
|
||||||
ul = lo(u)
|
|
||||||
uh = hi(u)
|
|
||||||
vl = lo(v)
|
|
||||||
vh = hi(v)
|
|
||||||
|
|
||||||
x0 = ul * vl
|
|
||||||
x1 = ul * vh
|
|
||||||
x2 = uh * vl
|
|
||||||
x3 = uh * vh
|
|
||||||
|
|
||||||
x1 += hi(x0) # This can't carry
|
|
||||||
x1 += x2 # but this can
|
|
||||||
if x1 < x2: # if carry, add it to x3
|
|
||||||
x3 += base
|
|
||||||
|
|
||||||
op(result.hi, x3 + hi(x1))
|
|
||||||
op(result.lo, (x1 shl p) or lo(x0))
|
|
||||||
|
|
||||||
func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) =
|
|
||||||
## Extended precision multiplication
|
|
||||||
extPrecMulImpl(result, `=`, u, v)
|
|
||||||
|
|
||||||
func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) =
|
|
||||||
## Extended precision fused in-place addition & multiplication
|
|
||||||
extPrecMulImpl(result, `+=`, u, v)
|
|
||||||
|
|
||||||
macro eqSym(x, y: untyped): untyped =
|
|
||||||
let eq = $x == $y # Unfortunately eqIdent compares to string.
|
|
||||||
result = newLit eq
|
|
||||||
|
|
||||||
func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T])
|
|
||||||
func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T])
|
|
||||||
# Forward declaration
|
|
||||||
|
|
||||||
template extPrecMulImpl*[T](result: var UintImpl[UintImpl[T]], op: untyped, x, y: UintImpl[T]) =
|
|
||||||
# See details at
|
|
||||||
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
|
|
||||||
# https://locklessinc.com/articles/256bit_arithmetic/
|
|
||||||
# https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
|
|
||||||
#
|
#
|
||||||
# We use the naive school grade multiplication instead of Karatsuba I.e.
|
# This is useful for
|
||||||
# z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
|
# - Barrett reduction
|
||||||
#
|
# - Approximating multiplication by a fractional constant in the form f(a) = K/C * a
|
||||||
# On modern architecture:
|
# with K and C known at compile-time.
|
||||||
# - addition and multiplication have the same cost
|
# We can instead find a well chosen M = (2^WordBitWidth)^w, with M > C (i.e. M is a power of 2 bigger than C)
|
||||||
# - Karatsuba would require to deal with potentially negative intermediate result
|
# Precompute P = K*M/C at compile-time
|
||||||
# and introduce branching
|
# and at runtime do P*a/M <=> P*a >> (WordBitWidth*w)
|
||||||
# - More total operations means more register moves
|
# i.e. prod_high_words(result, P, a, w)
|
||||||
|
|
||||||
var z1: type x
|
# We use Product Scanning / Comba multiplication
|
||||||
|
var t, u, v = Word(0) # Will raise warning on empty iterations
|
||||||
|
var z: Limbs[rLen] # zero-init, ensure on stack and removes in-place problems
|
||||||
|
|
||||||
# Low part and hi part - z0 & z2
|
# The previous 2 columns can affect the lowest word due to carries
|
||||||
when eqSym(op, `+=`):
|
# but not the ones before (we accumulate in 3 words (t, u, v))
|
||||||
extPrecAddMul(result.lo, x.lo, y.lo)
|
const w = lowestWordIndex - 2
|
||||||
extPrecAddMul(result.hi, x.hi, y.hi)
|
|
||||||
else:
|
|
||||||
extPrecMul(result.lo, x.lo, y.lo)
|
|
||||||
extPrecMul(result.hi, x.hi, y.hi)
|
|
||||||
|
|
||||||
## TODO - fuse those parts and reduce the number of carry checks
|
staticFor i, max(0, w), min(a.len+b.len, r.len+lowestWordIndex):
|
||||||
# Middle part - z1 - 1st mul
|
const ib = min(b.len-1, i)
|
||||||
extPrecMul(z1, x.hi, y.lo)
|
const ia = i - ib
|
||||||
result.lo.hi += z1.lo
|
staticFor j, 0, min(a.len - ia, ib+1):
|
||||||
if result.lo.hi < z1.lo:
|
mulAcc(t, u, v, a[ia+j], b[ib-j])
|
||||||
inc result.hi
|
|
||||||
|
|
||||||
result.hi.lo += z1.hi
|
when i >= lowestWordIndex:
|
||||||
if result.hi.lo < z1.hi:
|
z[i-lowestWordIndex] = v
|
||||||
inc result.hi.hi
|
v = u
|
||||||
|
u = t
|
||||||
|
t = Word(0)
|
||||||
|
|
||||||
# Middle part - z1 - 2nd mul
|
r = z
|
||||||
extPrecMul(z1, x.lo, y.hi)
|
|
||||||
result.lo.hi += z1.lo
|
|
||||||
if result.lo.hi < z1.lo:
|
|
||||||
inc result.hi
|
|
||||||
|
|
||||||
result.hi.lo += z1.hi
|
|
||||||
if result.hi.lo < z1.hi:
|
|
||||||
inc result.hi.hi
|
|
||||||
|
|
||||||
func extPrecAddMul[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) =
|
|
||||||
## Extended precision fused in-place addition & multiplication
|
|
||||||
extPrecMulImpl(result, `+=`, u, v)
|
|
||||||
|
|
||||||
func extPrecMul*[T](result: var UintImpl[UintImpl[T]], u, v: UintImpl[T]) =
|
|
||||||
## Extended precision multiplication
|
|
||||||
extPrecMulImpl(result, `=`, u, v)
|
|
||||||
|
|
||||||
func `*`*[T](x, y: UintImpl[T]): UintImpl[T] {.inline.}=
|
|
||||||
## Multiplication for multi-precision unsigned uint
|
|
||||||
#
|
|
||||||
# For our representation, it is similar to school grade multiplication
|
|
||||||
# Consider hi and lo as if they were digits
|
|
||||||
#
|
|
||||||
# 12
|
|
||||||
# X 15
|
|
||||||
# ------
|
|
||||||
# 10 lo*lo -> z0
|
|
||||||
# 5 hi*lo -> z1
|
|
||||||
# 2 lo*hi -> z1
|
|
||||||
# 10 hi*hi -- z2
|
|
||||||
# ------
|
|
||||||
# 180
|
|
||||||
#
|
|
||||||
# If T is a type
|
|
||||||
# For T * T --> T we don't need to compute z2 as it always overflow
|
|
||||||
# For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
|
|
||||||
|
|
||||||
extPrecMul(result, x.lo, y.lo)
|
|
||||||
result.hi += x.lo * y.hi + x.hi * y.lo
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user