[WIP] Division and multiplication optimization (#21)

* Clean-up code, part 1

* Managed to get the best borrow code for the non-inlined subtraction #10

* Implement in-place subtraction in terms of subtraction #10

* Another unneeded proc removal / temporary step

* More cleanup

* Upgrade benchmark to Uint256

* Special case when the divisor is less than halfSize: 2x speed 🔥 (still 4x slower than ttmath on Uint256)

* Division: special case if the dividend can overflow. 10% improvement.

* Forgot to undo normalization (why did the test pass?)

* First part: special cases of fast division

* Change bitops, simplify bithacks to detect new fast division cases

* 25% speed increase. Within 3x of ttmath

* Reimplement multiplication with minimum allocation

* Fix call. Now only 2x slower than ttmath

* Prepare for optimizing comparison operators

* Comparison inlining and optimization. 25% speed increase. 50% slower than ttmath now 🔥

* Fix comparison, optimize one()

* Inline initMpUintImpl for another 20% speed. Only 20% slower than ttmath without ASM
Mamy Ratsimbazafy 2018-04-21 12:12:05 +02:00 committed by GitHub
parent 7a5fc76561
commit 1749e0e575
18 changed files with 517 additions and 427 deletions


@ -18,10 +18,10 @@ echo "Warmup: " & $(stop - start) & "s"
start = cpuTime()
block:
var foo = 123.initMpUint(128)
var foo = 123.initMpUint(256)
for i in 0 ..< 10_000_000:
foo += i.initMpUint(128) * i.initMpUint(128) mod 456.initMpUint(128)
foo = foo mod 789.initMpUint(128)
foo += i.initMpUint(256) * i.initMpUint(256) mod 456.initMpUint(256)
foo = foo mod 789.initMpUint(256)
stop = cpuTime()
echo "Library: " & $(stop - start) & "s"


@ -11,7 +11,7 @@
import
strutils,
../private/[uint_type, size_mpuintimpl]
../private/[uint_type, getSize]
func tohexBE*[T: uint8 or uint16 or uint32 or uint64](x: T): string =
## Stringify an uint to hex, Most significant byte on the left
@ -31,7 +31,7 @@ func tohexBE*(x: MpUintImpl): string =
## Stringify an uint to hex, Most significant byte on the left
## i.e. a (2.uint128)^64 + 1 will be 0000000100000001
const size = size_mpuintimpl(x) div 8
const size = getSize(x) div 8
let bytes = cast[ptr array[size, byte]](x.unsafeaddr)


@ -8,5 +8,4 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./uint_public, ./uint_init
export uint_public, uint_init


@ -7,61 +7,24 @@
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./uint_type, stdlib_bitops, size_mpuintimpl
import ./uint_type, stdlib_bitops
export stdlib_bitops
# We reuse bitops from Nim standard lib and optimize it further on x86.
# On x86 clz it is implemented as bitscanreverse then xor and we need to again xor/sub.
# We need the bsr instructions so we xor again hoping for the compiler to only keep 1.
# We reuse bitops from the Nim standard library and extend it for multi-precision integers.
# MpInt must not rely on undefined behaviour, because we often scan zero limbs (e.g. the hi part when 1 is stored in a uint128).
# countLeadingZeroBits must also return the bit size of the type for a zero input, not 0 as in the stdlib.
proc bit_length*(x: SomeInteger): int {.noSideEffect.}=
when nimvm:
when sizeof(x) <= 4: result = if x == 0: 0 else: fastlog2_nim(x.uint32)
else: result = if x == 0: 0 else: fastlog2_nim(x.uint64)
else:
when useGCC_builtins:
when sizeof(x) <= 4: result = if x == 0: 0 else: builtin_clz(x.uint32) xor 31.cint
else: result = if x == 0: 0 else: builtin_clzll(x.uint64) xor 63.cint
elif useVCC_builtins:
when sizeof(x) <= 4:
result = if x == 0: 0 else: vcc_scan_impl(bitScanReverse, x.culong)
elif arch64:
result = if x == 0: 0 else: vcc_scan_impl(bitScanReverse64, x.uint64)
else:
result = if x == 0: 0 else: fastlog2_nim(x.uint64)
elif useICC_builtins:
when sizeof(x) <= 4:
result = if x == 0: 0 else: icc_scan_impl(bitScanReverse, x.uint32)
elif arch64:
result = if x == 0: 0 else: icc_scan_impl(bitScanReverse64, x.uint64)
else:
result = if x == 0: 0 else: fastlog2_nim(x.uint64)
else:
when sizeof(x) <= 4:
result = if x == 0: 0 else: fastlog2_nim(x.uint32)
else:
result = if x == 0: 0 else: fastlog2_nim(x.uint64)
proc bit_length*(n: MpUintImpl): int {.noSideEffect.}=
## Calculates how many bits are necessary to represent the number
const maxHalfRepr = n.lo.type.sizeof * 8 - 1
# Changing the following to an if expression somehow transform the whole ASM to 5 branches
# instead of the 4 expected (with the inline ASM from bit_length_impl)
# Also there does not seems to be a way to generate a conditional mov
let hi_bitlen = n.hi.bit_length
result = if hi_bitlen == 0: n.lo.bit_length
else: hi_bitlen + maxHalfRepr
proc countLeadingZeroBits*(x: MpUintImpl): int {.inline, nosideeffect.} =
func countLeadingZeroBits*(n: MpUintImpl): int {.inline.} =
## Returns the number of leading zero bits in integer.
const maxHalfRepr = size_mpuintimpl(x) div 2
const maxHalfRepr = getSize(n) div 2
let hi_clz = x.hi.countLeadingZeroBits
let hi_clz = n.hi.countLeadingZeroBits
result = if hi_clz == maxHalfRepr:
x.lo.countLeadingZeroBits + maxHalfRepr
n.lo.countLeadingZeroBits + maxHalfRepr
else: hi_clz
func bit_length*(n: SomeInteger): int {.inline.}=
## Calculates how many bits are necessary to represent the number
result = getSize(n) - n.countLeadingZeroBits
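As a quick illustration of the behaviour the procs above are meant to guarantee (a sketch, not part of the diff; it assumes the private modules compile as shown):

import ./uint_type, ./bithacks          # assumed module paths, matching the imports in this diff

let one128  = MpUintImpl[uint64](hi: 0'u64, lo: 1'u64)   # 128-bit value 1
let zero128 = MpUintImpl[uint64](hi: 0'u64, lo: 0'u64)

doAssert one128.countLeadingZeroBits  == 127   # hi is all zero: 64 + clz(lo)
doAssert zero128.countLeadingZeroBits == 128   # scanning zero yields the full bit size, no UB
doAssert bit_length(1'u64) == 1                # getSize - clz
doAssert bit_length(0'u64) == 0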


@ -7,14 +7,14 @@
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./uint_type, ./size_mpuintimpl,
import ./uint_type,
macros
proc initMpUintImpl*[InType, OutType](x: InType, _: typedesc[OutType]): OutType {.noSideEffect.} =
func initMpUintImpl*[InType, OutType](x: InType, _: typedesc[OutType]): OutType {.inline.} =
const
size_in = size_mpuintimpl(x)
size_out = size_mpuintimpl(result)
size_in = getSize(x)
size_out = getSize(result)
static:
assert size_out >= size_in, "The result type size should be equal or bigger than the input type size"
@ -26,23 +26,27 @@ proc initMpUintImpl*[InType, OutType](x: InType, _: typedesc[OutType]): OutType
else:
result.lo = initMpUintImpl(x, type result.lo)
proc toSubtype*[T: SomeInteger](b: bool, _: typedesc[T]): T {.noSideEffect, inline.}=
func toSubtype*[T: SomeInteger](b: bool, _: typedesc[T]): T {.inline.}=
b.T
proc toSubtype*[T: MpUintImpl](b: bool, _: typedesc[T]): T {.noSideEffect, inline.}=
func toSubtype*[T: MpUintImpl](b: bool, _: typedesc[T]): T {.inline.}=
type SubTy = type result.lo
result.lo = toSubtype(b, SubTy)
proc zero*[T: BaseUint](_: typedesc[T]): T {.noSideEffect, inline.}=
func zero*[T: BaseUint](_: typedesc[T]): T {.inline.}=
discard
proc one*[T: BaseUint](_: typedesc[T]): T {.noSideEffect, inline.}=
func one*[T: BaseUint](_: typedesc[T]): T {.inline.}=
when T is SomeUnsignedInt:
result = T(1)
else:
result.lo = one(type result.lo)
let r_ptr = cast[ptr array[getSize(result) div 8, byte]](result.addr)
when system.cpuEndian == bigEndian:
r_ptr[0] = 1
else:
r_ptr[r_ptr[].len - 1] = 1
proc toUint*(n: MpUIntImpl): auto {.noSideEffect, inline.}=
func toUint*(n: MpUIntImpl): auto {.inline.}=
## Casts a multiprecision integer to an uint of the same size
# TODO: uint128 support
@ -57,11 +61,11 @@ proc toUint*(n: MpUIntImpl): auto {.noSideEffect, inline.}=
else:
raise newException("Unreachable. MpUInt must be 16-bit minimum and a power of 2")
proc toUint*(n: SomeUnsignedInt): SomeUnsignedInt {.noSideEffect, inline.}=
func toUint*(n: SomeUnsignedInt): SomeUnsignedInt {.inline.}=
## No-op overload of multi-precision int casting
n
proc asDoubleUint*(n: BaseUint): auto {.noSideEffect, inline.} =
func asDoubleUint*(n: BaseUint): auto {.inline.} =
## Convert an integer or MpUint to an uint with double the size
type Double = (
@ -73,7 +77,7 @@ proc asDoubleUint*(n: BaseUint): auto {.noSideEffect, inline.} =
n.toUint.Double
proc toMpUintImpl*(n: uint16|uint32|uint64): auto {.noSideEffect, inline.} =
func toMpUintImpl*(n: uint16|uint32|uint64): auto {.inline.} =
## Cast an integer to the corresponding size MpUintImpl
# Sometimes direct casting doesn't work and we must cast through a pointer
@ -84,6 +88,6 @@ proc toMpUintImpl*(n: uint16|uint32|uint64): auto {.noSideEffect, inline.} =
elif n is uint16:
return (cast[ptr [MpUintImpl[uint8]]](unsafeAddr n))[]
proc toMpUintImpl*(n: MpUintImpl): MpUintImpl {.noSideEffect, inline.} =
func toMpUintImpl*(n: MpUintImpl): MpUintImpl {.inline.} =
## No op
n
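For illustration (not part of the diff), the endianness-aware byte write in one() above just sets the least significant byte of a zero-initialised result, so for a 128-bit MpUintImpl the expected behaviour is:

let u = one(MpUintImpl[uint64])            # 128-bit value 1, built without recursing through nested limbs
doAssert u.lo == 1'u64 and u.hi == 0'u64   # only the least significant byte is set

let z = zero(MpUintImpl[uint64])           # default-initialised, i.e. all limbs zero
doAssert z.lo == 0'u64 and z.hi == 0'u64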


@ -1,12 +0,0 @@
# Mpint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
proc divmod*(x, y: SomeInteger): tuple[quot, rem: SomeInteger] {.noSideEffect, inline.}=
# hopefully the compiler fuse that in a single op
(x div y, x mod y)


@ -1,36 +0,0 @@
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./uint_type, macros
proc size_mpuintimpl*(x: NimNode): static[int] =
# Size of doesn't always work at compile-time, pending PR https://github.com/nim-lang/Nim/pull/5664
var multiplier = 1
var node = x.getTypeInst
while node.kind == nnkBracketExpr:
assert eqIdent(node[0], "MpuintImpl")
multiplier *= 2
node = node[1]
# node[1] has the type
# size(node[1]) * multiplier is the size in byte
# For optimization we cast to the biggest possible uint
result = if eqIdent(node, "uint64"): multiplier * 64
elif eqIdent(node, "uint32"): multiplier * 32
elif eqIdent(node, "uint16"): multiplier * 16
else: multiplier * 8
macro size_mpuintimpl*(x: typed): untyped =
let size = size_mpuintimpl(x)
result = quote do:
`size`


@ -24,6 +24,13 @@
## At this time only `fastLog2`, `firstSetBit, `countLeadingZeroBits`, `countTrailingZeroBits`
## may return undefined and/or platform dependant value if given invalid input.
# Bitops from the standard lib modified for MpInt use.
# - No undefined behaviour or flag needed
# - Note that for CountLeadingZero, it returns sizeof(input) * 8
# instead of 0
const useBuiltins* = not defined(noIntrinsicsBitOpts)
# const noUndefined* = defined(noUndefinedBitOpts)
const useGCC_builtins* = (defined(gcc) or defined(llvm_gcc) or defined(clang)) and useBuiltins
@ -32,7 +39,7 @@ const useVCC_builtins* = defined(vcc) and useBuiltins
const arch64* = sizeof(int) == 8
proc fastlog2_nim*(x: uint32): int {.inline, nosideeffect.} =
func fastlog2_nim(x: uint32): int {.inline.} =
## Quickly find the log base 2 of a 32-bit or less integer.
# https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
# https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers
@ -46,7 +53,7 @@ proc fastlog2_nim*(x: uint32): int {.inline, nosideeffect.} =
v = v or v shr 16
result = lookup[uint32(v * 0x07C4ACDD'u32) shr 27].int
proc fastlog2_nim*(x: uint64): int {.inline, nosideeffect.} =
func fastlog2_nim(x: uint64): int {.inline.} =
## Quickly find the log base 2 of a 64-bit integer.
# https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
# https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers
@ -89,15 +96,14 @@ elif useICC_builtins:
discard fnc(index.addr, v)
index.int
proc countLeadingZeroBits*(x: SomeInteger): int {.inline, nosideeffect.} =
func countLeadingZeroBits*(x: SomeInteger): int {.inline.} =
## Returns the number of leading zero bits in integer.
## If `x` is zero, the result is the bit size of the type
## (this differs from the stdlib, which returns 0).
# when noUndefined:
if x == 0:
return sizeof(x) * 8
return sizeof(x) * 8 # Note: this differs from the stdlib, which returns 0
when nimvm:
when sizeof(x) <= 4: result = sizeof(x)*8 - 1 - fastlog2_nim(x.uint32)
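The zero-input guarantee called out in the comments above, written as assertions (illustrative, not part of the diff):

doAssert countLeadingZeroBits(0'u8)  == 8    # the stdlib version would return 0 here
doAssert countLeadingZeroBits(0'u64) == 64
doAssert countLeadingZeroBits(1'u64) == 63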


@ -0,0 +1,38 @@
# Mpint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./bithacks, ./conversion,
./uint_type,
./uint_comparison,
./uint_bitwise_ops
# ############ Addition & Subtraction ############ #
proc `+=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
## In-place addition for multi-precision unsigned int
type SubTy = type x.lo
x.lo += y.lo
x.hi += (x.lo < y.lo).toSubtype(SubTy) + y.hi
proc `+`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
# Addition for multi-precision unsigned int
result = x
result += y
proc `-`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
# Subtraction for multi-precision unsigned int
type SubTy = type x.lo
result.lo = x.lo - y.lo
result.hi = x.hi - y.hi - (x.lo < y.lo).toSubtype(SubTy)
proc `-=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
## In-place subtraction for multi-precision unsigned int
x = x - y
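A small sanity check of the comparison-based carry and borrow above (illustrative sketch, not part of the diff):

var a = MpUintImpl[uint64](hi: 0'u64, lo: high(uint64))
a += MpUintImpl[uint64](hi: 0'u64, lo: 1'u64)
doAssert a.hi == 1'u64 and a.lo == 0'u64           # lo wrapped around, (x.lo < y.lo) supplied the carry

let b = a - MpUintImpl[uint64](hi: 0'u64, lo: 1'u64)
doAssert b.hi == 0'u64 and b.lo == high(uint64)    # the borrow propagates back through hi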


@ -1,118 +0,0 @@
# Mpint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./bithacks, ./conversion,
./uint_type,
./uint_comparison,
./uint_bitwise_ops,
./size_mpuintimpl
# ############ Addition & Substraction ############ #
proc `+=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
## In-place addition for multi-precision unsigned int
#
# Optimized assembly should contain adc instruction (add with carry)
# Clang on MacOS does with the -d:release switch and MpUint[uint32] (uint64)
type SubTy = type x.lo
x.lo += y.lo
x.hi += (x.lo < y.lo).toSubtype(SubTy) + y.hi
proc `+`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
# Addition for multi-precision unsigned int
result = x
result += y
proc `-=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
## In-place substraction for multi-precision unsigned int
#
# Optimized assembly should contain sbb instruction (substract with borrow)
# Clang on MacOS does with the -d:release switch and MpUint[uint32] (uint64)
type SubTy = type x.lo
x.hi -= (x.lo < y.lo).toSubtype(SubTy) + y.hi
x.lo -= y.lo
proc `-`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
# Substraction for multi-precision unsigned int
result = x
result -= y
# ################### Multiplication ################### #
proc naiveMulImpl[T: MpUintImpl](x, y: T): MpUintImpl[T] {.noSideEffect, noInit, inline.}
# Forward declaration
proc naiveMul*[T: BaseUint](x, y: T): MpUintImpl[T] {.noSideEffect, noInit, inline.}=
## Naive multiplication algorithm with extended precision
const size = size_mpuintimpl(x)
when size in {8, 16, 32}:
# Use types twice bigger to do the multiplication
cast[type result](x.asDoubleUint * y.asDoubleUint)
elif size == 64: # uint64 or MpUint[uint32]
# We cannot double uint64 to uint128
cast[type result](naiveMulImpl(x.toMpUintImpl, y.toMpUintImpl))
else:
# Case: at least uint128 * uint128 --> uint256
cast[type result](naiveMulImpl(x, y))
proc naiveMulImpl[T: MpUintImpl](x, y: T): MpUintImpl[T] {.noSideEffect, noInit, inline.}=
# See details at
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
# https://locklessinc.com/articles/256bit_arithmetic/
# https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
#
# We use the naive school grade multiplication instead of Karatsuba I.e.
# z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
#
# On modern architecture:
# - addition and multiplication have the same cost
# - Karatsuba would require to deal with potentially negative intermediate result
# and introduce branching
# - More total operations means more register moves
const halfSize = size_mpuintimpl(x) div 2
let
z0 = naiveMul(x.lo, y.lo)
tmp = naiveMul(x.hi, y.lo)
var z1 = tmp
z1 += naiveMul(x.hi, y.lo)
let z2 = (z1 < tmp).toSubtype(T) + naiveMul(x.hi, y.hi)
let tmp2 = initMpUintImpl(z1.lo shl halfSize, T)
result.lo = tmp2
result.lo += z0
result.hi = (result.lo < tmp2).toSubtype(T) + z2 + initMpUintImpl(z1.hi, type result.hi)
proc `*`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit.}=
## Multiplication for multi-precision unsigned uint
#
# For our representation, it is similar to school grade multiplication
# Consider hi and lo as if they were digits
#
# 12
# X 15
# ------
# 10 lo*lo -> z0
# 5 hi*lo -> z1
# 2 lo*hi -> z1
# 10 hi*hi -- z2
# ------
# 180
#
# If T is a type
# For T * T --> T we don't need to compute z2 as it always overflow
# For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
result = naiveMul(x.lo, y.lo)
result.hi += (naiveMul(x.hi, y.lo) + naiveMul(x.lo, y.hi)).lo


@ -7,7 +7,7 @@
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./uint_type, ./size_mpuintimpl, ./conversion
import ./uint_type, ./conversion
func `not`*(x: MpUintImpl): MpUintImpl {.noInit, inline.}=
@ -40,7 +40,7 @@ func `shl`*(x: MpUintImpl, y: SomeInteger): MpUintImpl {.inline.}=
# TODO: would it be better to reimplement this using an array of bytes/uint64
# That opens up to endianness issues.
const halfSize = size_mpuintimpl(x) div 2
const halfSize = getSize(x) div 2
let defect = halfSize - int(y)
if y == 0:
@ -55,7 +55,7 @@ func `shl`*(x: MpUintImpl, y: SomeInteger): MpUintImpl {.inline.}=
func `shr`*(x: MpUintImpl, y: SomeInteger): MpUintImpl {.inline.}=
## Compute the `shift right` operation of x and y
const halfSize = size_mpuintimpl(x) div 2
const halfSize = getSize(x) div 2
if y == 0:
return x


@ -7,45 +7,94 @@
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./uint_type, ./size_mpuintimpl, macros
import ./uint_type, macros
macro cast_optim(x: typed): untyped =
let size = size_mpuintimpl(x)
macro optim(x: typed): untyped =
let size = getSize(x)
if size > 64:
result = quote do:
cast[array[`size` div 64, uint64]](`x`)
array[`size` div 64, uint64]
elif size == 64:
result = quote do:
cast[uint64](`x`)
uint64
elif size == 32:
result = quote do:
cast[uint32](`x`)
uint32
elif size == 16:
result = quote do:
cast[uint16](`x`)
uint16
elif size == 8:
result = quote do:
cast[uint8](`x`)
uint8
else:
error "Unreachable path reached"
proc isZero*(n: SomeUnsignedInt): bool {.noSideEffect,inline.} =
func isZero*(n: SomeUnsignedInt): bool {.inline.} =
n == 0
proc isZero*(n: MpUintImpl): bool {.noSideEffect,inline.} =
n == (type n)()
func isZero*(n: MpUintImpl): bool {.inline.} =
proc `<`*(x, y: MpUintImpl): bool {.noSideEffect, noInit, inline.}=
(x.hi < y.hi) or ((x.hi == y.hi) and x.lo < y.lo)
when optim(`n`) is array:
for val in cast[optim(n)](n):
if val != 0:
return false
return true
else:
cast[optim(n)](n) == 0
proc `==`*(x, y: MpuintImpl): bool {.noSideEffect, noInit, inline.}=
func `<`*(x, y: MpUintImpl): bool {.noInit, inline.}=
when optim(x) is array:
let
x_ptr = cast[ptr optim(x)](x.unsafeaddr)
y_ptr = cast[ptr optim(y)](y.unsafeaddr)
when system.cpuEndian == bigEndian:
for i in 0..<x_ptr[].len:
if x_ptr[i] != y_ptr[i]:
return x_ptr[i] < y_ptr[i]
return false # They're equal
else: # littleEndian, the most significant bytes are on the right
for i in countdown(x_ptr[].len - 1, 0):
if x_ptr[i] != y_ptr[i]:
return x_ptr[i] < y_ptr[i]
return false # They're equal
else:
cast[optim(x)](x) < cast[optim(y)](y)
func `==`*(x, y: MpUintImpl): bool {.noInit, inline.}=
# Equal comparison for multi-precision integers
# We cast to array of uint64 because the default comparison is slow
result = cast_optim(x) == cast_optim(y)
when optim(x) is array:
let
x_ptr = cast[ptr optim(x)](x.unsafeaddr)
y_ptr = cast[ptr optim(y)](y.unsafeaddr)
proc `<=`*(x, y: MpUintImpl): bool {.noSideEffect, noInit, inline.}=
for i in 0..<x_ptr[].len:
if x_ptr[i] != y_ptr[i]:
return false
return true
else:
cast[optim(x)](x) == cast[optim(y)](y)
func `<=`*(x, y: MpUintImpl): bool {.noInit, inline.}=
# Lower or equal comparison for multi-precision integers
result = if x == y: true
else: x < y
when optim(x) is array:
let
x_ptr = cast[ptr optim(x)](x.unsafeaddr)
y_ptr = cast[ptr optim(y)](y.unsafeaddr)
when system.cpuEndian == bigEndian:
for i in 0..<x_ptr[].len:
if x_ptr[i] != y_ptr[i]:
return x_ptr[i] < y_ptr[i]
return true # They're equal
else: # littleEndian, the most significant bytes are on the right
for i in countdown(x_ptr[].len - 1, 0):
if x_ptr[i] != y_ptr[i]:
return x_ptr[i] < y_ptr[i]
return true # They're equal
else:
cast[optim(x)](x) <= cast[optim(y)](y)
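A minimal sketch (not part of the diff) of the limb-wise comparisons above on a 128-bit value; for MpUintImpl[uint64], optim resolves to array[2, uint64], so the array branches are taken:

let
  x = MpUintImpl[uint64](hi: 1'u64, lo: 0'u64)          # 2^64
  y = MpUintImpl[uint64](hi: 0'u64, lo: high(uint64))   # 2^64 - 1

doAssert y < x            # decided on the most significant limb
doAssert y <= x
doAssert not (x == y)
doAssert x == x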


@ -11,9 +11,8 @@ import ./bithacks, ./conversion,
./uint_type,
./uint_comparison,
./uint_bitwise_ops,
./uint_binary_ops,
./size_mpuintimpl,
./primitive_divmod
./uint_addsub,
./uint_mul
# ################### Division ################### #
# We use the following algorithm:
@ -45,142 +44,96 @@ import ./bithacks, ./conversion,
## ##
###################################################################################################################
func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.}
func div2n1n(q, r: var MpUintImpl, ah, al, b: MpUintImpl) {.inline.}
func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T)
func div2n1n(q, r: var MpUintImpl, ah, al, b: MpUintImpl)
# Forward declaration
func div3n2n[T]( q, r1, r0: var MpUintImpl[T],
a2, a1, a0: MpUintImpl[T],
b1, b0: MpUintImpl[T]) {.inline.}=
mixin div2n1n
proc divmod*(x, y: SomeInteger): tuple[quot, rem: SomeInteger] {.noSideEffect, inline.}=
# Hopefully the compiler fuses these into a single op
(x div y, x mod y)
type T = type q
func divmod*[T](x, y: MpUintImpl[T]): tuple[quot, rem: MpUintImpl[T]]
# Forward declaration
func div3n2n[T]( q: var MpUintImpl[T],
r: var MpUintImpl[MpUintImpl[T]],
a2, a1, a0: MpUintImpl[T],
b: MpUintImpl[MpUintImpl[T]]) =
var
c: T
c: MpUintImpl[T]
d: MpUintImpl[MpUintImpl[T]]
carry: bool
if a2 < b1:
div2n1n(q, c, a2, a1, b1)
if a2 < b.hi:
div2n1n(q, c, a2, a1, b.hi)
else:
q = zero(type q) - one(type q) # We want 0xFFFFF ....
c = a1 + b1
c = a1 + b.hi
if c < a1:
carry = true
let
d = naiveMul(q, b0)
b = MpUintImpl[type c](hi: b1, lo: b0)
extPrecMul[T](d, q, b.lo)
let ca0 = MpUintImpl[type c](hi: c, lo: a0)
var r = MpUintImpl[type c](hi: c, lo: a0) - d
r = ca0 - d
if (not carry) and (d > r):
if (not carry) and (d > ca0):
q -= one(type q)
r += b
# if there was no carry
if r > b:
q -= one(type q)
r += b
r1 = r.hi
r0 = r.lo
template sub_ddmmss[T](sh, sl, ah, al, bh, bl: T) =
sl = al - bl
sh = ah - bh - (al < bl).T
func lo[T:SomeUnsignedInt](x: T): T {.inline.} =
const
p = T.sizeof * 8 div 2
base = 1 shl p
mask = base - 1
result = x and mask
func hi[T:SomeUnsignedInt](x: T): T {.inline.} =
const
p = T.sizeof * 8 div 2
result = x shr p
func umul_ppmm[T](w1, w0: var T, u, v: T) =
const
p = (T.sizeof * 8 div 2)
base = 1 shl p
proc div3n2n[T: SomeUnsignedInt](
q: var T,
r: var MpUintImpl[T],
a2, a1, a0: T,
b: MpUintImpl[T]) =
var
x0, x1, x2, x3: T
let
ul = u.lo
uh = u.hi
vl = v.lo
vh = v.hi
x0 = ul * vl
x1 = ul * vh
x2 = uh * vl
x3 = uh * vh
x1 += x0.hi # This can't carry
x1 += x2 # but this can
if x1 < x2: # if carry, add it to x3
x3 += base
w1 = x3 + x1.hi
w0 = (x1 shl p) + x0.lo
proc div3n2n( q, r1, r0: var SomeUnsignedInt,
a2, a1, a0: SomeUnsignedInt,
b1, b0: SomeUnsignedInt) {.inline.}=
mixin div2n1n
type T = type q
var
c, d1, d0: T
c: T
d: MpUintImpl[T]
carry: bool
if a2 < b1:
div2n1n(q, c, a2, a1, b1)
if a2 < b.hi:
div2n1n(q, c, a2, a1, b.hi)
else:
q = 0.T - 1.T # We want 0xFFFFF ....
c = a1 + b1
c = a1 + b.hi
if c < a1:
carry = true
umul_ppmm(d1, d0, q, b0)
sub_ddmmss(r1, r0, c, a0, d1, d0)
extPrecMul[T](d, q, b.lo)
let ca0 = MpUintImpl[T](hi: c, lo: a0)
r = ca0 - d
if (not carry) and ((d1 > c) or ((d1 == c) and (d0 > a0))):
q -= 1.T
r0 += b0
r1 += b1
if r0 < b0:
inc r1
if (not carry) and d > ca0:
dec q
r += b
if (r1 > b1) or ((r1 == b1) and (r0 >= b0)):
q -= 1.T
r0 += b0
r1 += b1
if r0 < b0:
inc r1
# if there was no carry
if r > b:
dec q
r += b
func div2n1n(q, r: var MpUintImpl, ah, al, b: MpUintImpl) {.inline.} =
func div2n1n(q, r: var MpUintImpl, ah, al, b: MpUintImpl) =
# assert countLeadingZeroBits(b) == 0, "Divisor was not normalized"
var s: MpUintImpl
div3n2n(q.hi, s.hi, s.lo, ah.hi, ah.lo, al.hi, b.hi, b.lo)
div3n2n(q.lo, r.hi, r.lo, s.hi, s.lo, al.lo, b.hi, b.lo)
div3n2n(q.hi, s, ah.hi, ah.lo, al.hi, b)
div3n2n(q.lo, r, s.hi, s.lo, al.lo, b)
func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.} =
func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) =
# assert countLeadingZeroBits(d) == 0, "Divisor was not normalized"
const
size = size_mpuintimpl(q)
size = getSize(q)
halfSize = size div 2
halfMask = (1.T shl halfSize) - 1.T
@ -192,10 +145,10 @@ func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.} =
# Fix the remainder, we're at most 2 iterations off
if r < m:
q -= 1.T
dec q
r += d_hi
if r >= d_hi and r < m:
q -= 1.T
dec q
r += d_hi
r -= m
(q, r)
@ -215,23 +168,114 @@ func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.} =
q = (q1 shl halfSize) or q2
r = r2
func divmod*[T](x, y: MpUintImpl[T]): tuple[quot, rem: MpUintImpl[T]] =
func divmodBZ[T](x, y: MpUintImpl[T], q, r: var MpUintImpl[T])=
# Normalization
assert y.isZero.not()
assert y.isZero.not() # This should be checked on release mode in the divmod caller proc
const halfSize = size_mpuintimpl(x) div 2
let clz = countLeadingZeroBits(y)
if y.hi.isZero:
# Shortcut if divisor is smaller than half the size of the type
let
xx = MpUintImpl[type x](lo: x) shl clz
yy = y shl clz
# Normalize
let
clz = countLeadingZeroBits(y.lo)
xx = x shl clz
yy = y.lo shl clz
# Compute
div2n1n(result.quot, result.rem, xx.hi, xx.lo, yy)
if x.hi < y.lo:
# If y is smaller than the base, normalizing x does not overflow.
# Compute directly
div2n1n(q.lo, r.lo, xx.hi, xx.lo, yy)
# Undo normalization
r.lo = r.lo shr clz
else:
# Normalizing x overflowed, we need to compute the high remainder first
(q.hi, r.hi) = divmod(x.hi, y.lo)
# Undo normalization
result.rem = result.rem shr clz
# Normalize the remainder. (x.lo is already normalized)
r.hi = r.hi shl clz
# Compute
div2n1n(q.lo, r.lo, r.hi, xx.lo, yy)
# Undo normalization
r.lo = r.lo shr clz
# Given size n, dividing a 2n number by a 1n normalized number
# always gives a 1n remainder.
r.hi = zero(T)
else: # General case
# Normalization
let clz = countLeadingZeroBits(y)
let
xx = MpUintImpl[type x](lo: x) shl clz
yy = y shl clz
# Compute
div2n1n(q, r, xx.hi, xx.lo, yy)
# Undo normalization
r = r shr clz
func divmodBS(x, y: MpUintImpl, q, r: var MpuintImpl) =
## Division for multi-precision unsigned int
## Implementation through binary shift division
assert y.isZero.not() # This should be checked on release mode in the divmod caller proc
type SubTy = type x.lo
var
shift = x.countLeadingZeroBits - y.countLeadingZeroBits
d = y shl shift
r = x
while shift >= 0:
q += q
if r >= d:
r -= d
q.lo = q.lo or one(SubTy)
d = d shr 1
dec(shift)
const BinaryShiftThreshold = 8 # If the difference in bit-length is below 8
# binary shift is probably faster
func divmod*[T](x, y: MpUintImpl[T]): tuple[quot, rem: MpUintImpl[T]]=
let x_clz = x.countLeadingZeroBits
let y_clz = y.countLeadingZeroBits
# We short-circuit division depending on special-cases.
# TODO: Constant-time division
if unlikely(y.isZero):
raise newException(DivByZeroError, "You attempted to divide by zero")
elif y_clz == (getSize(y) - 1):
# y is one
result.quot = x
elif (x.hi or y.hi).isZero:
# If computing just on the low part is enough
(result.quot.lo, result.rem.lo) = divmod(x.lo, y.lo)
elif (y and (y - one(type y))).isZero:
# y is a power of 2. (this also matches 0 but it was eliminated earlier)
# TODO. Would it be faster to use countTrailingZero (ctz) + clz == size(y) - 1?
# Especially because we shift by ctz after.
# It is a bit tricky with recursive types. An empty n.lo means 0 or sizeof(n.lo)
let y_ctz = getSize(y) - y_clz - 1
result.quot = x shr y_ctz
result.rem = x and (y - one(type y))
elif x == y:
result.quot.lo = one(T)
elif x < y:
result.rem = x
elif (y_clz - x_clz) < BinaryShiftThreshold:
divmodBS(x, y, result.quot, result.rem)
else:
divmodBZ(x, y, result.quot, result.rem)
func `div`*(x, y: MpUintImpl): MpUintImpl {.inline.} =
## Division operation for multi-precision unsigned uint
@ -280,31 +324,3 @@ func `mod`*(x, y: MpUintImpl): MpUintImpl {.inline.} =
# - Google Abseil: https://github.com/abseil/abseil-cpp/tree/master/absl/numeric
# - Crypto libraries like libsecp256k1, OpenSSL, ... though they are not generics. (uint256 only for example)
# Note: GMP/MPFR are GPL. The papers can be used but not their code.
# ######################################################################
# School division
# proc divmod*(x, y: MpUintImpl): tuple[quot, rem: MpUintImpl] {.noSideEffect.}=
# ## Division for multi-precision unsigned uint
# ## Returns quotient + reminder in a (quot, rem) tuple
# #
# # Implementation through binary shift division
# if unlikely(y.isZero):
# raise newException(DivByZeroError, "You attempted to divide by zero")
# type SubTy = type x.lo
# var
# shift = x.bit_length - y.bit_length
# d = y shl shift
# result.rem = x
# while shift >= 0:
# result.quot += result.quot
# if result.rem >= d:
# result.rem -= d
# result.quot.lo = result.quot.lo or one(SubTy)
# d = d shr 1
# dec(shift)
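The special cases in the divmod dispatcher above can be exercised through the public API roughly as follows (illustrative sketch, not part of the diff; the top-level module name is assumed from the repository layout):

import mpint

let x = 123456789.initMpUint(256)

doAssert x div 1.initMpUint(256) == x                       # divisor is one
doAssert x div x == 1.initMpUint(256)                       # x == y
doAssert (x div (x + 1.initMpUint(256))).isZero             # x < y
doAssert x mod 64.initMpUint(256) == (123456789 mod 64).initMpUint(256)   # power-of-two divisor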

src/private/uint_mul.nim (new file, 138 lines)

@ -0,0 +1,138 @@
# Mpint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./conversion,
./uint_type,
./uint_comparison,
./uint_addsub
# ################### Multiplication ################### #
func lo[T:SomeUnsignedInt](x: T): T {.inline.} =
const
p = T.sizeof * 8 div 2
base = 1 shl p
mask = base - 1
result = x and mask
func hi[T:SomeUnsignedInt](x: T): T {.inline.} =
const
p = T.sizeof * 8 div 2
result = x shr p
# Not generic: otherwise Nim reports an ambiguous call with the T: MpUintImpl overload
func extPrecMul*(result: var MpUintImpl[uint8], x, y: uint8) =
## Extended precision multiplication
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
func extPrecMul*(result: var MpUintImpl[uint16], x, y: uint16) =
## Extended precision multiplication
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
func extPrecMul*(result: var MpUintImpl[uint32], x, y: uint32) =
## Extended precision multiplication
result = cast[type result](x.asDoubleUint * y.asDoubleUint)
func extPrecAddMul[T: uint8 or uint16 or uint32](result: var MpUintImpl[T], x, y: T) =
## Extended precision fused in-place addition & multiplication
result += cast[type result](x.asDoubleUint * y.asDoubleUint)
template extPrecMulImpl(result: var MpUintImpl[uint64], op: untyped, u, v: uint64) =
const
p = 64 div 2
base = 1 shl p
var
x0, x1, x2, x3: uint64
let
ul = u.lo
uh = u.hi
vl = v.lo
vh = v.hi
x0 = ul * vl
x1 = ul * vh
x2 = uh * vl
x3 = uh * vh
x1 += x0.hi # This can't carry
x1 += x2 # but this can
if x1 < x2: # if carry, add it to x3
x3 += base
op(result.hi, x3 + x1.hi)
op(result.lo, (x1 shl p) or x0.lo)
func extPrecMul*(result: var MpUintImpl[uint64], u, v: uint64) =
## Extended precision multiplication
extPrecMulImpl(result, `=`, u, v)
func extPrecAddMul(result: var MpUintImpl[uint64], u, v: uint64) =
## Extended precision fused in-place addition & multiplication
extPrecMulImpl(result, `+=`, u, v)
func extPrecMul*[T](result: var MpUintImpl[MpUintImpl[T]], x, y: MpUintImpl[T]) =
# See details at
# https://en.wikipedia.org/wiki/Karatsuba_algorithm
# https://locklessinc.com/articles/256bit_arithmetic/
# https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
#
# We use the naive school grade multiplication instead of Karatsuba I.e.
# z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
#
# On modern architecture:
# - addition and multiplication have the same cost
# - Karatsuba would require to deal with potentially negative intermediate result
# and introduce branching
# - More total operations means more register moves
var z1: MpUintImpl[T]
# Low part - z0
extPrecMul(result.lo, x.lo, y.lo)
# Middle part - z1
extPrecMul(z1, x.hi, y.lo)
let carry_check = z1
extPrecAddMul(z1, x.lo, y.hi)
if z1 < carry_check:
result.hi.lo = one(T)
# High part - z2
result.hi.lo += z1.hi
extPrecAddMul(result.hi, x.hi, y.hi)
# Finalize low part
result.lo.hi += z1.lo
if result.lo.hi < z1.lo:
result.hi += one(MpUintImpl[T])
func `*`*[T](x, y: MpUintImpl[T]): MpUintImpl[T] {.inline.}=
## Multiplication for multi-precision unsigned int
#
# For our representation, it is similar to school grade multiplication
# Consider hi and lo as if they were digits
#
# 12
# X 15
# ------
# 10 lo*lo -> z0
# 5 hi*lo -> z1
# 2 lo*hi -> z1
# 10 hi*hi -- z2
# ------
# 180
#
# If T is a type
# For T * T --> T we don't need to compute z2 as it always overflows
# For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
extPrecMul(result, x.lo, y.lo)
result.hi += x.lo * y.hi + x.hi * y.lo
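A quick cross-check of the truncating `*` above against native 64-bit arithmetic (illustrative, not part of the diff; toMpUintImpl and toUint come from the conversion module shown earlier):

let
  a = 0xFFFF_FFFF'u64
  b = 0x1_0000_0001'u64
  p = toMpUintImpl(a) * toMpUintImpl(b)   # multiplies two MpUintImpl[uint32] values

doAssert p.toUint == a * b                # the low 64 bits agree with wrapping uint64 multiplication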


@ -51,6 +51,32 @@ else:
else:
error "Fatal: unreachable"
proc getSize*(x: NimNode): static[int] =
# sizeof doesn't always work at compile time, pending PR https://github.com/nim-lang/Nim/pull/5664
var multiplier = 1
var node = x.getTypeInst
while node.kind == nnkBracketExpr:
assert eqIdent(node[0], "MpuintImpl")
multiplier *= 2
node = node[1]
# node[1] has the type
# size(node[1]) * multiplier is the size in byte
# For optimization we cast to the biggest possible uint
result = if eqIdent(node, "uint64"): multiplier * 64
elif eqIdent(node, "uint32"): multiplier * 32
elif eqIdent(node, "uint16"): multiplier * 16
else: multiplier * 8
macro getSize*(x: typed): untyped =
let size = getSize(x)
result = quote do:
`size`
type
# ### Private ### #
# If this is not in the same type section
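What the compile-time getSize above computes for a nested type, as a sketch (not part of the diff):

var n: MpUintImpl[MpUintImpl[uint64]]   # two levels of nesting over uint64 limbs
doAssert getSize(n) == 256              # resolved at compile time by walking the nnkBracketExpr chain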


@ -14,7 +14,7 @@ import ./private/bithacks, ./private/conversion,
import typetraits
proc initMpUint*[T: SomeInteger](n: T, bits: static[int]): MpUint[bits] {.noSideEffect.} =
func initMpUint*[T: SomeInteger](n: T, bits: static[int]): MpUint[bits] {.inline.}=
assert n >= 0.T
when result.data is MpuintImpl:
let len = n.bit_length


@ -15,14 +15,14 @@ type
UInt256* = MpUint[256]
template make_conv(conv_name: untyped, size: int): untyped =
proc `convname`*(n: SomeInteger): MpUint[size] {.noSideEffect, inline, noInit.}=
func `convname`*(n: SomeInteger): MpUint[size] {.inline, noInit.}=
n.initMpUint(size)
make_conv(u128, 128)
make_conv(u256, 256)
template make_unary(op, ResultTy): untyped =
proc `op`*(x: MpUint): ResultTy {.noInit, inline, noSideEffect.} =
func `op`*(x: MpUint): ResultTy {.noInit, inline.} =
when resultTy is MpUint:
result.data = op(x.data)
else:
@ -30,7 +30,7 @@ template make_unary(op, ResultTy): untyped =
export op
template make_binary(op, ResultTy): untyped =
proc `op`*(x, y: MpUint): ResultTy {.noInit, inline, noSideEffect.} =
func `op`*(x, y: MpUint): ResultTy {.noInit, inline.} =
when ResultTy is MpUint:
result.data = op(x.data, y.data)
else:
@ -38,31 +38,33 @@ template make_binary(op, ResultTy): untyped =
export `op`
template make_binary_inplace(op): untyped =
proc `op`*(x: var MpUint, y: MpUint) {.inline, noSideEffect.} =
func `op`*(x: var MpUint, y: MpUint) {.inline.} =
op(x.data, y.data)
export op
import ./private/uint_binary_ops
import ./private/uint_addsub
make_binary(`+`, MpUint)
make_binary_inplace(`+=`)
make_binary(`-`, MpUint)
make_binary_inplace(`-=`)
import ./private/uint_mul
make_binary(`*`, MpUint)
import ./private/primitive_divmod,
./private/uint_division
import ./private/uint_div
make_binary(`div`, MpUint)
make_binary(`mod`, MpUint)
proc divmod*(x, y: MpUint): tuple[quot, rem: MpUint] {.noInit, inline, noSideEffect.} =
func divmod*(x, y: MpUint): tuple[quot, rem: MpUint] {.noInit, inline.} =
(result.quot.data, result.rem.data) = divmod(x.data, y.data)
import ./private/uint_comparison
make_binary(`<`, bool)
make_binary(`<=`, bool)
proc isZero*(x: MpUint): bool {.inline, noSideEffect.} = isZero x
make_binary(`==`, bool)
func isZero*(x: MpUint): bool {.inline.} = isZero x.data
import ./private/uint_bitwise_ops
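For reference, a call such as make_binary(`+`, MpUint) above expands to roughly the following (illustrative, not the generated code verbatim):

func `+`*(x, y: MpUint): MpUint {.noInit, inline.} =
  result.data = x.data + y.data    # ResultTy is MpUint, so the data-wrapping branch of the when is taken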


@ -14,27 +14,42 @@ suite "Testing comparison operators":
a = 10.initMpUint(16)
b = 15.initMpUint(16)
c = 150'u16
d = 4.initMpUint(128) shl 64
e = 4.initMpUint(128)
f = 4.initMpUint(128) shl 65
test "< operator":
check: a < b
check: not (a + b < b)
check: not (a + a + a < b + b)
check: not (a * b < cast[MpUint[16]](c))
check:
a < b
not (a + b < b)
not (a + a + a < b + b)
not (a * b < cast[MpUint[16]](c))
e < d
d < f
test "<= operator":
check: a <= b
check: not (a + b <= b)
check: a + a + a <= b + b
check: a * b <= cast[MpUint[16]](c)
check:
a <= b
not (a + b <= b)
a + a + a <= b + b
a * b <= cast[MpUint[16]](c)
e <= d
d <= f
test "> operator":
check: b > a
check: not (b > a + b)
check: not (b + b > a + a + a)
check: not (cast[Mpuint[16]](c) > a * b)
check:
b > a
not (b > a + b)
not (b + b > a + a + a)
not (cast[Mpuint[16]](c) > a * b)
d > e
f > d
test ">= operator":
check: b >= a
check: not (b >= a + b)
check: b + b >= a + a + a
check: cast[MpUint[16]](c) >= a * b
check:
b >= a
not (b >= a + b)
b + b >= a + a + a
cast[MpUint[16]](c) >= a * b
d >= e
f >= d