diff --git a/benchmarks/bench_mod.nim b/benchmarks/bench_mod.nim
index 1bde734..e25d916 100644
--- a/benchmarks/bench_mod.nim
+++ b/benchmarks/bench_mod.nim
@@ -18,10 +18,10 @@ echo "Warmup: " & $(stop - start) & "s"
 
 start = cpuTime()
 block:
-  var foo = 123.initMpUint(128)
+  var foo = 123.initMpUint(256)
   for i in 0 ..< 10_000_000:
-    foo += i.initMpUint(128) * i.initMpUint(128) mod 456.initMpUint(128)
-    foo = foo mod 789.initMpUint(128)
+    foo += i.initMpUint(256) * i.initMpUint(256) mod 456.initMpUint(256)
+    foo = foo mod 789.initMpUint(256)
 
 stop = cpuTime()
 echo "Library: " & $(stop - start) & "s"
diff --git a/src/debug/debugutils.nim b/src/debug/debugutils.nim
index 2b3f1ef..76ae902 100644
--- a/src/debug/debugutils.nim
+++ b/src/debug/debugutils.nim
@@ -11,7 +11,7 @@
 
 import  strutils,
-        ../private/[uint_type, size_mpuintimpl]
+        ../private/[uint_type, getSize]
 
 func tohexBE*[T: uint8 or uint16 or uint32 or uint64](x: T): string =
   ## Stringify an uint to hex, Most significant byte on the left
@@ -31,7 +31,7 @@ func tohexBE*(x: MpUintImpl): string =
   ## Stringify an uint to hex, Most significant byte on the left
   ## i.e. a (2.uint128)^64 + 1 will be 0000000100000001
 
-  const size = size_mpuintimpl(x) div 8
+  const size = getSize(x) div 8
 
   let bytes = cast[ptr array[size, byte]](x.unsafeaddr)
diff --git a/src/mpint.nim b/src/mpint.nim
index 0a66265..b047a3b 100644
--- a/src/mpint.nim
+++ b/src/mpint.nim
@@ -8,5 +8,4 @@
 # at your option. This file may not be copied, modified, or distributed except according to those terms.
 
 import ./uint_public, ./uint_init
-
 export uint_public, uint_init
diff --git a/src/private/bithacks.nim b/src/private/bithacks.nim
index 7c80bb9..d9e034f 100644
--- a/src/private/bithacks.nim
+++ b/src/private/bithacks.nim
@@ -7,61 +7,24 @@
 #
 # at your option. This file may not be copied, modified, or distributed except according to those terms.
 
-import ./uint_type, stdlib_bitops, size_mpuintimpl
+import ./uint_type, stdlib_bitops
+export stdlib_bitops
 
-# We reuse bitops from Nim standard lib and optimize it further on x86.
-# On x86 clz it is implemented as bitscanreverse then xor and we need to again xor/sub.
-# We need the bsr instructions so we xor again hoping for the compiler to only keep 1.
+# We reuse bitops from Nim standard lib, and expand it for multi-precision int.
+# MpInt relies on no undefined behaviour, as we often scan 0 (if 1 is stored in a uint128 for example).
+# Also countLeadingZeroBits must return the size of the type and not 0 like in the stdlib
 
-proc bit_length*(x: SomeInteger): int {.noSideEffect.}=
-  when nimvm:
-    when sizeof(x) <= 4: result = if x == 0: 0 else: fastlog2_nim(x.uint32)
-    else: result = if x == 0: 0 else: fastlog2_nim(x.uint64)
-  else:
-    when useGCC_builtins:
-      when sizeof(x) <= 4: result = if x == 0: 0 else: builtin_clz(x.uint32) xor 31.cint
-      else: result = if x == 0: 0 else: builtin_clzll(x.uint64) xor 63.cint
-    elif useVCC_builtins:
-      when sizeof(x) <= 4:
-        result = if x == 0: 0 else: vcc_scan_impl(bitScanReverse, x.culong)
-      elif arch64:
-        result = if x == 0: 0 else: vcc_scan_impl(bitScanReverse64, x.uint64)
-      else:
-        result = if x == 0: 0 else: fastlog2_nim(x.uint64)
-    elif useICC_builtins:
-      when sizeof(x) <= 4:
-        result = if x == 0: 0 else: icc_scan_impl(bitScanReverse, x.uint32)
-      elif arch64:
-        result = if x == 0: 0 else: icc_scan_impl(bitScanReverse64, x.uint64)
-      else:
-        result = if x == 0: 0 else: fastlog2_nim(x.uint64)
-    else:
-      when sizeof(x) <= 4:
-        result = if x == 0: 0 else: fastlog2_nim(x.uint32)
-      else:
-        result = if x == 0: 0 else: fastlog2_nim(x.uint64)
-
-
-proc bit_length*(n: MpUintImpl): int {.noSideEffect.}=
-  ## Calculates how many bits are necessary to represent the number
-
-  const maxHalfRepr = n.lo.type.sizeof * 8 - 1
-
-  # Changing the following to an if expression somehow transform the whole ASM to 5 branches
-  # instead of the 4 expected (with the inline ASM from bit_length_impl)
-  # Also there does not seems to be a way to generate a conditional mov
-  let hi_bitlen = n.hi.bit_length
-  result = if hi_bitlen == 0: n.lo.bit_length
-           else: hi_bitlen + maxHalfRepr
-
-
-proc countLeadingZeroBits*(x: MpUintImpl): int {.inline, nosideeffect.} =
+func countLeadingZeroBits*(n: MpUintImpl): int {.inline.} =
   ## Returns the number of leading zero bits in integer.
-  const maxHalfRepr = size_mpuintimpl(x) div 2
+  const maxHalfRepr = getSize(n) div 2
 
-  let hi_clz = x.hi.countLeadingZeroBits
+  let hi_clz = n.hi.countLeadingZeroBits
 
   result = if hi_clz == maxHalfRepr:
-             x.lo.countLeadingZeroBits + maxHalfRepr
+             n.lo.countLeadingZeroBits + maxHalfRepr
            else: hi_clz
+
+func bit_length*(n: SomeInteger): int {.inline.}=
+  ## Calculates how many bits are necessary to represent the number
+  result = getSize(n) - n.countLeadingZeroBits
diff --git a/src/private/conversion.nim b/src/private/conversion.nim
index 8b90193..5fcb2ec 100644
--- a/src/private/conversion.nim
+++ b/src/private/conversion.nim
@@ -7,14 +7,14 @@
 #
 # at your option. This file may not be copied, modified, or distributed except according to those terms.
 
-import  ./uint_type, ./size_mpuintimpl,
+import  ./uint_type,
         macros
 
-proc initMpUintImpl*[InType, OutType](x: InType, _: typedesc[OutType]): OutType {.noSideEffect.} =
+func initMpUintImpl*[InType, OutType](x: InType, _: typedesc[OutType]): OutType {.inline.} =
   const
-    size_in = size_mpuintimpl(x)
-    size_out = size_mpuintimpl(result)
+    size_in = getSize(x)
+    size_out = getSize(result)
 
   static:
     assert size_out >= size_in, "The result type size should be equal or bigger than the input type size"
@@ -26,23 +26,27 @@ proc initMpUintImpl*[InType, OutType](x: InType, _: typedesc[OutType]): OutType
   else:
     result.lo = initMpUintImpl(x, type result.lo)
 
-proc toSubtype*[T: SomeInteger](b: bool, _: typedesc[T]): T {.noSideEffect, inline.}=
+func toSubtype*[T: SomeInteger](b: bool, _: typedesc[T]): T {.inline.}=
   b.T
 
-proc toSubtype*[T: MpUintImpl](b: bool, _: typedesc[T]): T {.noSideEffect, inline.}=
+func toSubtype*[T: MpUintImpl](b: bool, _: typedesc[T]): T {.inline.}=
   type SubTy = type result.lo
   result.lo = toSubtype(b, SubTy)
 
-proc zero*[T: BaseUint](_: typedesc[T]): T {.noSideEffect, inline.}=
+func zero*[T: BaseUint](_: typedesc[T]): T {.inline.}=
   discard
 
-proc one*[T: BaseUint](_: typedesc[T]): T {.noSideEffect, inline.}=
+func one*[T: BaseUint](_: typedesc[T]): T {.inline.}=
   when T is SomeUnsignedInt:
     result = T(1)
   else:
-    result.lo = one(type result.lo)
+    let r_ptr = cast[ptr array[getSize(result) div 8, byte]](result.addr)
+    when system.cpuEndian == bigEndian:
+      r_ptr[0] = 1
+    else:
+      r_ptr[r_ptr[].len - 1] = 1
 
-proc toUint*(n: MpUIntImpl): auto {.noSideEffect, inline.}=
+func toUint*(n: MpUIntImpl): auto {.inline.}=
   ## Casts a multiprecision integer to an uint of the same size
 
   # TODO: uint128 support
@@ -57,11 +61,11 @@ proc toUint*(n: MpUIntImpl): auto {.noSideEffect, inline.}=
   else:
     raise newException("Unreachable. MpUInt must be 16-bit minimum and a power of 2")
 
-proc toUint*(n: SomeUnsignedInt): SomeUnsignedInt {.noSideEffect, inline.}=
+func toUint*(n: SomeUnsignedInt): SomeUnsignedInt {.inline.}=
   ## No-op overload of multi-precision int casting
   n
 
-proc asDoubleUint*(n: BaseUint): auto {.noSideEffect, inline.} =
+func asDoubleUint*(n: BaseUint): auto {.inline.} =
   ## Convert an integer or MpUint to an uint with double the size
 
   type Double = (
@@ -73,7 +77,7 @@ proc asDoubleUint*(n: BaseUint): auto {.noSideEffect, inline.} =
 
   n.toUint.Double
 
-proc toMpUintImpl*(n: uint16|uint32|uint64): auto {.noSideEffect, inline.} =
+func toMpUintImpl*(n: uint16|uint32|uint64): auto {.inline.} =
   ## Cast an integer to the corresponding size MpUintImpl
 
   # Sometimes direct casting doesn't work and we must cast through a pointer
@@ -84,6 +88,6 @@ proc toMpUintImpl*(n: uint16|uint32|uint64): auto {.noSideEffect, inline.} =
   elif n is uint16:
     return (cast[ptr [MpUintImpl[uint8]]](unsafeAddr n))[]
 
-proc toMpUintImpl*(n: MpUintImpl): MpUintImpl {.noSideEffect, inline.} =
+func toMpUintImpl*(n: MpUintImpl): MpUintImpl {.inline.} =
   ## No op
   n
diff --git a/src/private/primitive_divmod.nim b/src/private/primitive_divmod.nim
deleted file mode 100644
index 525f6d5..0000000
--- a/src/private/primitive_divmod.nim
+++ /dev/null
@@ -1,12 +0,0 @@
-# Mpint
-# Copyright 2018 Status Research & Development GmbH
-# Licensed under either of
-#
-#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
-#  * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
-#
-# at your option. This file may not be copied, modified, or distributed except according to those terms.
-
-proc divmod*(x, y: SomeInteger): tuple[quot, rem: SomeInteger] {.noSideEffect, inline.}=
-  # hopefully the compiler fuse that in a single op
-  (x div y, x mod y)
diff --git a/src/private/size_mpuintimpl.nim b/src/private/size_mpuintimpl.nim
deleted file mode 100644
index 68e2b74..0000000
--- a/src/private/size_mpuintimpl.nim
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2018 Status Research & Development GmbH
-# Licensed under either of
-#
-#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
-#  * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
-#
-# at your option. This file may not be copied, modified, or distributed except according to those terms.
-
-import ./uint_type, macros
-
-
-proc size_mpuintimpl*(x: NimNode): static[int] =
-
-  # Size of doesn't always work at compile-time, pending PR https://github.com/nim-lang/Nim/pull/5664
-
-  var multiplier = 1
-  var node = x.getTypeInst
-
-  while node.kind == nnkBracketExpr:
-    assert eqIdent(node[0], "MpuintImpl")
-    multiplier *= 2
-    node = node[1]
-
-  # node[1] has the type
-  # size(node[1]) * multiplier is the size in byte
-
-  # For optimization we cast to the biggest possible uint
-  result = if eqIdent(node, "uint64"): multiplier * 64
-           elif eqIdent(node, "uint32"): multiplier * 32
-           elif eqIdent(node, "uint16"): multiplier * 16
-           else: multiplier * 8
-
-macro size_mpuintimpl*(x: typed): untyped =
-  let size = size_mpuintimpl(x)
-  result = quote do:
-    `size`
diff --git a/src/private/stdlib_bitops.nim b/src/private/stdlib_bitops.nim
index 6d3d165..e6040e2 100644
--- a/src/private/stdlib_bitops.nim
+++ b/src/private/stdlib_bitops.nim
@@ -24,6 +24,13 @@
 ## At this time only `fastLog2`, `firstSetBit, `countLeadingZeroBits`, `countTrailingZeroBits`
 ## may return undefined and/or platform dependant value if given invalid input.
 
+
+# Bitops from the standard lib modified for MpInt use.
+# - No undefined behaviour or flag needed
+# - Note that for CountLeadingZero, it returns sizeof(input) * 8
+#   instead of 0
+
+
 const useBuiltins* = not defined(noIntrinsicsBitOpts)
 # const noUndefined* = defined(noUndefinedBitOpts)
 const useGCC_builtins* = (defined(gcc) or defined(llvm_gcc) or defined(clang)) and useBuiltins
@@ -32,7 +39,7 @@ const useVCC_builtins* = defined(vcc) and useBuiltins
 const arch64* = sizeof(int) == 8
 
-proc fastlog2_nim*(x: uint32): int {.inline, nosideeffect.} =
+func fastlog2_nim(x: uint32): int {.inline.} =
   ## Quickly find the log base 2 of a 32-bit or less integer.
   # https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
   # https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers
@@ -46,7 +53,7 @@ proc fastlog2_nim*(x: uint32): int {.inline, nosideeffect.} =
   v = v or v shr 16
   result = lookup[uint32(v * 0x07C4ACDD'u32) shr 27].int
 
-proc fastlog2_nim*(x: uint64): int {.inline, nosideeffect.} =
+func fastlog2_nim(x: uint64): int {.inline.} =
   ## Quickly find the log base 2 of a 64-bit integer.
   # https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
   # https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers
@@ -89,15 +96,14 @@ elif useICC_builtins:
     discard fnc(index.addr, v)
     index.int
 
-
-proc countLeadingZeroBits*(x: SomeInteger): int {.inline, nosideeffect.} =
+func countLeadingZeroBits*(x: SomeInteger): int {.inline.} =
   ## Returns the number of leading zero bits in integer.
   ## If `x` is zero, when ``noUndefinedBitOpts`` is set, result is 0,
   ## otherwise result is undefined.
 
   # when noUndefined:
   if x == 0:
-    return sizeof(x) * 8
+    return sizeof(x) * 8 # Note this differs from the stdlib which returns 0
 
   when nimvm:
     when sizeof(x) <= 4: result = sizeof(x)*8 - 1 - fastlog2_nim(x.uint32)
diff --git a/src/private/uint_addsub.nim b/src/private/uint_addsub.nim
new file mode 100644
index 0000000..a8e1c95
--- /dev/null
+++ b/src/private/uint_addsub.nim
@@ -0,0 +1,38 @@
+# Mpint
+# Copyright 2018 Status Research & Development GmbH
+# Licensed under either of
+#
+#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
+#  * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
+#
+# at your option. This file may not be copied, modified, or distributed except according to those terms.
+
+import  ./bithacks, ./conversion,
+        ./uint_type,
+        ./uint_comparison,
+        ./uint_bitwise_ops
+
+# ############ Addition & Subtraction ############ #
+
+proc `+=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
+  ## In-place addition for multi-precision unsigned int
+
+  type SubTy = type x.lo
+  x.lo += y.lo
+  x.hi += (x.lo < y.lo).toSubtype(SubTy) + y.hi
+
+proc `+`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
+  # Addition for multi-precision unsigned int
+  result = x
+  result += y
+
+proc `-`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
+  # Subtraction for multi-precision unsigned int
+
+  type SubTy = type x.lo
+  result.lo = x.lo - y.lo
+  result.hi = x.hi - y.hi - (x.lo < y.lo).toSubtype(SubTy)
+
+proc `-=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
+  ## In-place subtraction for multi-precision unsigned int
+  x = x - y
diff --git a/src/private/uint_binary_ops.nim b/src/private/uint_binary_ops.nim
deleted file mode 100644
index 020a4aa..0000000
--- a/src/private/uint_binary_ops.nim
+++ /dev/null
@@ -1,118 +0,0 @@
-# Mpint
-# Copyright 2018 Status Research & Development GmbH
-# Licensed under either of
-#
-#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
-#  * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
-#
-# at your option. This file may not be copied, modified, or distributed except according to those terms.
-
-import  ./bithacks, ./conversion,
-        ./uint_type,
-        ./uint_comparison,
-        ./uint_bitwise_ops,
-        ./size_mpuintimpl
-
-# ############ Addition & Substraction ############ #
-
-proc `+=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
-  ## In-place addition for multi-precision unsigned int
-  #
-  # Optimized assembly should contain adc instruction (add with carry)
-  # Clang on MacOS does with the -d:release switch and MpUint[uint32] (uint64)
-  type SubTy = type x.lo
-  x.lo += y.lo
-  x.hi += (x.lo < y.lo).toSubtype(SubTy) + y.hi
-
-proc `+`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
-  # Addition for multi-precision unsigned int
-  result = x
-  result += y
-
-proc `-=`*(x: var MpUintImpl, y: MpUintImpl) {.noSideEffect, inline.}=
-  ## In-place substraction for multi-precision unsigned int
-  #
-  # Optimized assembly should contain sbb instruction (substract with borrow)
-  # Clang on MacOS does with the -d:release switch and MpUint[uint32] (uint64)
-  type SubTy = type x.lo
-  x.hi -= (x.lo < y.lo).toSubtype(SubTy) + y.hi
-  x.lo -= y.lo
-
-proc `-`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit, inline.}=
-  # Substraction for multi-precision unsigned int
-  result = x
-  result -= y
-
-
-# ################### Multiplication ################### #
-
-proc naiveMulImpl[T: MpUintImpl](x, y: T): MpUintImpl[T] {.noSideEffect, noInit, inline.}
-  # Forward declaration
-
-proc naiveMul*[T: BaseUint](x, y: T): MpUintImpl[T] {.noSideEffect, noInit, inline.}=
-  ## Naive multiplication algorithm with extended precision
-
-  const size = size_mpuintimpl(x)
-
-  when size in {8, 16, 32}:
-    # Use types twice bigger to do the multiplication
-    cast[type result](x.asDoubleUint * y.asDoubleUint)
-
-  elif size == 64: # uint64 or MpUint[uint32]
-    # We cannot double uint64 to uint128
-    cast[type result](naiveMulImpl(x.toMpUintImpl, y.toMpUintImpl))
-  else:
-    # Case: at least uint128 * uint128 --> uint256
-    cast[type result](naiveMulImpl(x, y))
-
-
-proc naiveMulImpl[T: MpUintImpl](x, y: T): MpUintImpl[T] {.noSideEffect, noInit, inline.}=
-  # See details at
-  # https://en.wikipedia.org/wiki/Karatsuba_algorithm
-  # https://locklessinc.com/articles/256bit_arithmetic/
-  # https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
-  #
-  # We use the naive school grade multiplication instead of Karatsuba I.e.
-  # z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
-  #
-  # On modern architecture:
-  #   - addition and multiplication have the same cost
-  #   - Karatsuba would require to deal with potentially negative intermediate result
-  #     and introduce branching
-  #   - More total operations means more register moves
-
-  const halfSize = size_mpuintimpl(x) div 2
-  let
-    z0 = naiveMul(x.lo, y.lo)
-    tmp = naiveMul(x.hi, y.lo)
-
-  var z1 = tmp
-  z1 += naiveMul(x.hi, y.lo)
-  let z2 = (z1 < tmp).toSubtype(T) + naiveMul(x.hi, y.hi)
-
-  let tmp2 = initMpUintImpl(z1.lo shl halfSize, T)
-  result.lo = tmp2
-  result.lo += z0
-  result.hi = (result.lo < tmp2).toSubtype(T) + z2 + initMpUintImpl(z1.hi, type result.hi)
-
-proc `*`*(x, y: MpUintImpl): MpUintImpl {.noSideEffect, noInit.}=
-  ## Multiplication for multi-precision unsigned uint
-  #
-  # For our representation, it is similar to school grade multiplication
-  # Consider hi and lo as if they were digits
-  #
-  #     12
-  # X   15
-  # ------
-  #     10   lo*lo -> z0
-  #     5    hi*lo -> z1
-  #      2   lo*hi -> z1
-  #    10    hi*hi -- z2
-  # ------
-  #    180
-  #
-  # If T is a type
-  # For T * T --> T we don't need to compute z2 as it always overflow
-  # For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
-  result = naiveMul(x.lo, y.lo)
-  result.hi += (naiveMul(x.hi, y.lo) + naiveMul(x.lo, y.hi)).lo
diff --git a/src/private/uint_bitwise_ops.nim b/src/private/uint_bitwise_ops.nim
index aadc0bf..9a62c2e 100644
--- a/src/private/uint_bitwise_ops.nim
+++ b/src/private/uint_bitwise_ops.nim
@@ -7,7 +7,7 @@
 #
 # at your option. This file may not be copied, modified, or distributed except according to those terms.
 
-import ./uint_type, ./size_mpuintimpl, ./conversion
+import ./uint_type, ./conversion
 
 func `not`*(x: MpUintImpl): MpUintImpl {.noInit, inline.}=
@@ -40,7 +40,7 @@ func `shl`*(x: MpUintImpl, y: SomeInteger): MpUintImpl {.inline.}=
   # TODO: would it be better to reimplement this using an array of bytes/uint64
   #       That opens up to endianness issues.
 
-  const halfSize = size_mpuintimpl(x) div 2
+  const halfSize = getSize(x) div 2
   let defect = halfSize - int(y)
 
   if y == 0:
@@ -55,7 +55,7 @@ func `shl`*(x: MpUintImpl, y: SomeInteger): MpUintImpl {.inline.}=
 func `shr`*(x: MpUintImpl, y: SomeInteger): MpUintImpl {.inline.}=
   ## Compute the `shift right` operation of x and y
 
-  const halfSize = size_mpuintimpl(x) div 2
+  const halfSize = getSize(x) div 2
 
   if y == 0:
     return x
diff --git a/src/private/uint_comparison.nim b/src/private/uint_comparison.nim
index b472326..2d0952d 100644
--- a/src/private/uint_comparison.nim
+++ b/src/private/uint_comparison.nim
@@ -7,45 +7,94 @@
 #
 # at your option. This file may not be copied, modified, or distributed except according to those terms.
 
-import ./uint_type, ./size_mpuintimpl, macros
+import ./uint_type, macros
 
-macro cast_optim(x: typed): untyped =
-  let size = size_mpuintimpl(x)
+macro optim(x: typed): untyped =
+  let size = getSize(x)
 
   if size > 64:
     result = quote do:
-      cast[array[`size` div 64, uint64]](`x`)
+      array[`size` div 64, uint64]
   elif size == 64:
     result = quote do:
-      cast[uint64](`x`)
+      uint64
   elif size == 32:
     result = quote do:
-      cast[uint32](`x`)
+      uint32
   elif size == 16:
     result = quote do:
-      cast[uint16](`x`)
+      uint16
   elif size == 8:
     result = quote do:
-      cast[uint8](`x`)
+      uint8
   else:
     error "Unreachable path reached"
 
-proc isZero*(n: SomeUnsignedInt): bool {.noSideEffect,inline.} =
+func isZero*(n: SomeUnsignedInt): bool {.inline.} =
   n == 0
 
-proc isZero*(n: MpUintImpl): bool {.noSideEffect,inline.} =
-  n == (type n)()
+func isZero*(n: MpUintImpl): bool {.inline.} =
 
-proc `<`*(x, y: MpUintImpl): bool {.noSideEffect, noInit, inline.}=
-  (x.hi < y.hi) or ((x.hi == y.hi) and x.lo < y.lo)
+  when optim(`n`) is array:
+    for val in cast[optim(n)](n):
+      if val != 0:
+        return false
+    return true
+  else:
+    cast[optim(n)](n) == 0
 
-proc `==`*(x, y: MpuintImpl): bool {.noSideEffect, noInit, inline.}=
+func `<`*(x, y: MpUintImpl): bool {.noInit, inline.}=
+
+  when optim(x) is array:
+    let
+      x_ptr = cast[ptr optim(x)](x.unsafeaddr)
+      y_ptr = cast[ptr optim(y)](y.unsafeaddr)
+
+    when system.cpuEndian == bigEndian:
+      for i in 0..
 r):
+  if (not carry) and (d > ca0):
     q -= one(type q)
     r += b
 
+  # if there was no carry
   if r > b:
     q -= one(type q)
     r += b
-  r1 = r.hi
-  r0 = r.lo
-
-template sub_ddmmss[T](sh, sl, ah, al, bh, bl: T) =
-  sl = al - bl
-  sh = ah - bh - (al < bl).T
-
-func lo[T:SomeUnsignedInt](x: T): T {.inline.} =
-  const
-    p = T.sizeof * 8 div 2
-    base = 1 shl p
-    mask = base - 1
-  result = x and mask
-
-func hi[T:SomeUnsignedInt](x: T): T {.inline.} =
-  const
-    p = T.sizeof * 8 div 2
-  result = x shr p
-
-func umul_ppmm[T](w1, w0: var T, u, v: T) =
-
-  const
-    p = (T.sizeof * 8 div 2)
-    base = 1 shl p
+proc div3n2n[T: SomeUnsignedInt](
+       q: var T,
+       r: var MpUintImpl[T],
+       a2, a1, a0: T,
+       b: MpUintImpl[T]) =
 
   var
-    x0, x1, x2, x3: T
-
-  let
-    ul = u.lo
-    uh = u.hi
-    vl = v.lo
-    vh = v.hi
-
-  x0 = ul * vl
-  x1 = ul * vh
-  x2 = uh * vl
-  x3 = uh * vh
-
-  x1 += x0.hi # This can't carry
-  x1 += x2 # but this can
-  if x1 < x2: # if carry, add it to x3
-    x3 += base
-
-  w1 = x3 + x1.hi
-  w0 = (x1 shl p) + x0.lo
-
-
-proc div3n2n( q, r1, r0: var SomeUnsignedInt,
-              a2, a1, a0: SomeUnsignedInt,
-              b1, b0: SomeUnsignedInt) {.inline.}=
-  mixin div2n1n
-
-  type T = type q
-
-  var
-    c, d1, d0: T
+    c: T
+    d: MpUintImpl[T]
     carry: bool
 
-  if a2 < b1:
-    div2n1n(q, c, a2, a1, b1)
+  if a2 < b.hi:
+    div2n1n(q, c, a2, a1, b.hi)
   else:
     q = 0.T - 1.T # We want 0xFFFFF ....
- c = a1 + b1 + c = a1 + b.hi if c < a1: carry = true - umul_ppmm(d1, d0, q, b0) - sub_ddmmss(r1, r0, c, a0, d1, d0) + extPrecMul[T](d, q, b.lo) + let ca0 = MpUintImpl[T](hi: c, lo: a0) + r = ca0 - d - if (not carry) and ((d1 > c) or ((d1 == c) and (d0 > a0))): - q -= 1.T - r0 += b0 - r1 += b1 - if r0 < b0: - inc r1 + if (not carry) and d > ca0: + dec q + r += b - if (r1 > b1) or ((r1 == b1) and (r0 >= b0)): - q -= 1.T - r0 += b0 - r1 += b1 - if r0 < b0: - inc r1 + # if there was no carry + if r > b: + dec q + r += b -func div2n1n(q, r: var MpUintImpl, ah, al, b: MpUintImpl) {.inline.} = +func div2n1n(q, r: var MpUintImpl, ah, al, b: MpUintImpl) = # assert countLeadingZeroBits(b) == 0, "Divisor was not normalized" var s: MpUintImpl - div3n2n(q.hi, s.hi, s.lo, ah.hi, ah.lo, al.hi, b.hi, b.lo) - div3n2n(q.lo, r.hi, r.lo, s.hi, s.lo, al.lo, b.hi, b.lo) + div3n2n(q.hi, s, ah.hi, ah.lo, al.hi, b) + div3n2n(q.lo, r, s.hi, s.lo, al.lo, b) -func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.} = +func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) = # assert countLeadingZeroBits(d) == 0, "Divisor was not normalized" const - size = size_mpuintimpl(q) + size = getSize(q) halfSize = size div 2 halfMask = (1.T shl halfSize) - 1.T @@ -192,10 +145,10 @@ func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.} = # Fix the reminder, we're at most 2 iterations off if r < m: - q -= 1.T + dec q r += d_hi if r >= d_hi and r < m: - q -= 1.T + dec q r += d_hi r -= m (q, r) @@ -215,23 +168,114 @@ func div2n1n[T: SomeunsignedInt](q, r: var T, n_hi, n_lo, d: T) {.inline.} = q = (q1 shl halfSize) or q2 r = r2 -func divmod*[T](x, y: MpUintImpl[T]): tuple[quot, rem: MpUintImpl[T]] = +func divmodBZ[T](x, y: MpUintImpl[T], q, r: var MpUintImpl[T])= - # Normalization - assert y.isZero.not() + assert y.isZero.not() # This should be checked on release mode in the divmod caller proc - const halfSize = size_mpuintimpl(x) div 2 - let clz = countLeadingZeroBits(y) + if y.hi.isZero: + # Shortcut if divisor is smaller than half the size of the type - let - xx = MpUintImpl[type x](lo: x) shl clz - yy = y shl clz + # Normalize + let + clz = countLeadingZeroBits(y.lo) + xx = x shl clz + yy = y.lo shl clz - # Compute - div2n1n(result.quot, result.rem, xx.hi, xx.lo, yy) + if x.hi < y.lo: + # If y is smaller than the base, normalizing x does not overflow. + # Compute directly + div2n1n(q.lo, r.lo, xx.hi, xx.lo, yy) + # Undo normalization + r.lo = r.lo shr clz + else: + # Normalizing x overflowed, we need to compute the high remainder first + (q.hi, r.hi) = divmod(x.hi, y.lo) - # Undo normalization - result.rem = result.rem shr clz + # Normalize the remainder. (x.lo is already normalized) + r.hi = r.hi shl clz + + # Compute + div2n1n(q.lo, r.lo, r.hi, xx.lo, yy) + + # Undo normalization + r.lo = r.lo shr clz + + # Given size n, dividing a 2n number by a 1n normalized number + # always gives a 1n remainder. 
+      r.hi = zero(T)
+
+  else: # General case
+    # Normalization
+    let clz = countLeadingZeroBits(y)
+
+    let
+      xx = MpUintImpl[type x](lo: x) shl clz
+      yy = y shl clz
+
+    # Compute
+    div2n1n(q, r, xx.hi, xx.lo, yy)
+
+    # Undo normalization
+    r = r shr clz
+
+func divmodBS(x, y: MpUintImpl, q, r: var MpuintImpl) =
+  ## Division for multi-precision unsigned uint
+  ## Implementation through binary shift division
+
+  assert y.isZero.not() # This should be checked in release mode in the divmod caller proc
+
+  type SubTy = type x.lo
+
+  var
+    shift = x.countLeadingZeroBits - y.countLeadingZeroBits
+    d = y shl shift
+
+  r = x
+
+  while shift >= 0:
+    q += q
+    if r >= d:
+      r -= d
+      q.lo = q.lo or one(SubTy)
+
+    d = d shr 1
+    dec(shift)
+
+const BinaryShiftThreshold = 8  # If the difference in bit-length is below 8
+                                # binary shift is probably faster
+
+func divmod*[T](x, y: MpUintImpl[T]): tuple[quot, rem: MpUintImpl[T]]=
+
+  let x_clz = x.countLeadingZeroBits
+  let y_clz = y.countLeadingZeroBits
+
+  # We short-circuit division depending on special-cases.
+  # TODO: Constant-time division
+  if unlikely(y.isZero):
+    raise newException(DivByZeroError, "You attempted to divide by zero")
+  elif y_clz == (getSize(y) - 1):
+    # y is one
+    result.quot = x
+  elif (x.hi or y.hi).isZero:
+    # If computing just on the low part is enough
+    (result.quot.lo, result.rem.lo) = divmod(x.lo, y.lo)
+  elif (y and (y - one(type y))).isZero:
+    # y is a power of 2. (this also matches 0 but it was eliminated earlier)
+    # TODO. Would it be faster to use countTrailingZero (ctz) + clz == size(y) - 1?
+    #       Especially because we shift by ctz after.
+    #       It is a bit tricky with recursive types. An empty n.lo means 0 or sizeof(n.lo)
+    let y_ctz = getSize(y) - y_clz - 1
+    result.quot = x shr y_ctz
+    result.rem = y_ctz.initMpUintImpl(MpUintImpl[T])
+    result.rem = result.rem and x
+  elif x == y:
+    result.quot.lo = one(T)
+  elif x < y:
+    result.rem = x
+  elif (y_clz - x_clz) < BinaryShiftThreshold:
+    divmodBS(x, y, result.quot, result.rem)
+  else:
+    divmodBZ(x, y, result.quot, result.rem)
 
 func `div`*(x, y: MpUintImpl): MpUintImpl {.inline.} =
   ## Division operation for multi-precision unsigned uint
@@ -280,31 +324,3 @@ func `mod`*(x, y: MpUintImpl): MpUintImpl {.inline.} =
 # - Google Abseil: https://github.com/abseil/abseil-cpp/tree/master/absl/numeric
 # - Crypto libraries like libsecp256k1, OpenSSL, ... though they are not generics. (uint256 only for example)
 # Note: GMP/MPFR are GPL. The papers can be used but not their code.
-
-# ######################################################################
-# School division
-
-# proc divmod*(x, y: MpUintImpl): tuple[quot, rem: MpUintImpl] {.noSideEffect.}=
-#   ## Division for multi-precision unsigned uint
-#   ## Returns quotient + reminder in a (quot, rem) tuple
-#   #
-#   # Implementation through binary shift division
-#   if unlikely(y.isZero):
-#     raise newException(DivByZeroError, "You attempted to divide by zero")
-
-#   type SubTy = type x.lo
-
-#   var
-#     shift = x.bit_length - y.bit_length
-#     d = y shl shift
-
-#   result.rem = x
-
-#   while shift >= 0:
-#     result.quot += result.quot
-#     if result.rem >= d:
-#       result.rem -= d
-#       result.quot.lo = result.quot.lo or one(SubTy)
-
-#     d = d shr 1
-#     dec(shift)
diff --git a/src/private/uint_mul.nim b/src/private/uint_mul.nim
new file mode 100644
index 0000000..4d64fb7
--- /dev/null
+++ b/src/private/uint_mul.nim
@@ -0,0 +1,138 @@
+# Mpint
+# Copyright 2018 Status Research & Development GmbH
+# Licensed under either of
+#
+#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
+#  * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
+#
+# at your option. This file may not be copied, modified, or distributed except according to those terms.
+
+import  ./conversion,
+        ./uint_type,
+        ./uint_comparison,
+        ./uint_addsub
+
+# ################### Multiplication ################### #
+
+func lo[T:SomeUnsignedInt](x: T): T {.inline.} =
+  const
+    p = T.sizeof * 8 div 2
+    base = 1 shl p
+    mask = base - 1
+  result = x and mask
+
+func hi[T:SomeUnsignedInt](x: T): T {.inline.} =
+  const
+    p = T.sizeof * 8 div 2
+  result = x shr p
+
+# No generic, somehow Nim is given ambiguous call with the T: MpUintImpl overload
+func extPrecMul*(result: var MpUintImpl[uint8], x, y: uint8) =
+  ## Extended precision multiplication
+  result = cast[type result](x.asDoubleUint * y.asDoubleUint)
+
+func extPrecMul*(result: var MpUintImpl[uint16], x, y: uint16) =
+  ## Extended precision multiplication
+  result = cast[type result](x.asDoubleUint * y.asDoubleUint)
+
+func extPrecMul*(result: var MpUintImpl[uint32], x, y: uint32) =
+  ## Extended precision multiplication
+  result = cast[type result](x.asDoubleUint * y.asDoubleUint)
+
+func extPrecAddMul[T: uint8 or uint16 or uint32](result: var MpUintImpl[T], x, y: T) =
+  ## Extended precision fused in-place addition & multiplication
+  result += cast[type result](x.asDoubleUint * y.asDoubleUint)
+
+template extPrecMulImpl(result: var MpUintImpl[uint64], op: untyped, u, v: uint64) =
+  const
+    p = 64 div 2
+    base = 1 shl p
+
+  var
+    x0, x1, x2, x3: uint64
+
+  let
+    ul = u.lo
+    uh = u.hi
+    vl = v.lo
+    vh = v.hi
+
+  x0 = ul * vl
+  x1 = ul * vh
+  x2 = uh * vl
+  x3 = uh * vh
+
+  x1 += x0.hi # This can't carry
+  x1 += x2 # but this can
+  if x1 < x2: # if carry, add it to x3
+    x3 += base
+
+  op(result.hi, x3 + x1.hi)
+  op(result.lo, (x1 shl p) or x0.lo)
+
+func extPrecMul*(result: var MpUintImpl[uint64], u, v: uint64) =
+  ## Extended precision multiplication
+  extPrecMulImpl(result, `=`, u, v)
+
+func extPrecAddMul(result: var MpUintImpl[uint64], u, v: uint64) =
+  ## Extended precision fused in-place addition & multiplication
+  extPrecMulImpl(result, `+=`, u, v)
+
+func extPrecMul*[T](result: var MpUintImpl[MpUintImpl[T]], x, y: MpUintImpl[T]) =
+  # See details at
+  # https://en.wikipedia.org/wiki/Karatsuba_algorithm
+  # https://locklessinc.com/articles/256bit_arithmetic/
+  # https://www.miracl.com/press/missing-a-trick-karatsuba-variations-michael-scott
+  #
+  # We use the naive school grade multiplication instead of Karatsuba I.e.
+  # z1 = x.hi * y.lo + x.lo * y.hi (Naive) = (x.lo - x.hi)(y.hi - y.lo) + z0 + z2 (Karatsuba)
+  #
+  # On modern architecture:
+  #   - addition and multiplication have the same cost
+  #   - Karatsuba would require to deal with potentially negative intermediate result
+  #     and introduce branching
+  #   - More total operations means more register moves
+
+  var z1: MpUintImpl[T]
+
+  # Low part - z0
+  extPrecMul(result.lo, x.lo, y.lo)
+
+  # Middle part - z1
+  extPrecMul(z1, x.hi, y.lo)
+  let carry_check = z1
+  extPrecAddMul(z1, x.lo, y.hi)
+  if z1 < carry_check:
+    result.hi.lo = one(T)
+
+  # High part - z2
+  result.hi.lo += z1.hi
+  extPrecAddMul(result.hi, x.hi, y.hi)
+
+  # Finalize low part
+  result.lo.hi += z1.lo
+  if result.lo.hi < z1.lo:
+    result.hi += one(MpUintImpl[T])
+
+func `*`*[T](x, y: MpUintImpl[T]): MpUintImpl[T] {.inline.}=
+  ## Multiplication for multi-precision unsigned uint
+  #
+  # For our representation, it is similar to school grade multiplication
+  # Consider hi and lo as if they were digits
+  #
+  #     12
+  # X   15
+  # ------
+  #     10   lo*lo -> z0
+  #     5    hi*lo -> z1
+  #      2   lo*hi -> z1
+  #    10    hi*hi -- z2
+  # ------
+  #    180
+  #
+  # If T is a type
+  # For T * T --> T we don't need to compute z2 as it always overflow
+  # For T * T --> 2T (uint64 * uint64 --> uint128) we use extra precision multiplication
+
+  extPrecMul(result, x.lo, y.lo)
+  result.hi += x.lo * y.hi + x.hi * y.lo
diff --git a/src/private/uint_type.nim b/src/private/uint_type.nim
index 14064c6..e78fd95 100644
--- a/src/private/uint_type.nim
+++ b/src/private/uint_type.nim
@@ -51,6 +51,32 @@ else:
   else: error "Fatal: unreachable"
 
+proc getSize*(x: NimNode): static[int] =
+
+  # Size of doesn't always work at compile-time, pending PR https://github.com/nim-lang/Nim/pull/5664
+
+  var multiplier = 1
+  var node = x.getTypeInst
+
+  while node.kind == nnkBracketExpr:
+    assert eqIdent(node[0], "MpuintImpl")
+    multiplier *= 2
+    node = node[1]
+
+  # node[1] has the type
+  # size(node[1]) * multiplier is the size in byte
+
+  # For optimization we cast to the biggest possible uint
+  result = if eqIdent(node, "uint64"): multiplier * 64
+           elif eqIdent(node, "uint32"): multiplier * 32
+           elif eqIdent(node, "uint16"): multiplier * 16
+           else: multiplier * 8
+
+macro getSize*(x: typed): untyped =
+  let size = getSize(x)
+  result = quote do:
+    `size`
+
 type
   # ### Private ### #
   # If this is not in the same type section
diff --git a/src/uint_init.nim b/src/uint_init.nim
index 610347c..78b161e 100644
--- a/src/uint_init.nim
+++ b/src/uint_init.nim
@@ -14,7 +14,7 @@ import ./private/bithacks, ./private/conversion,
 
 import typetraits
 
-proc initMpUint*[T: SomeInteger](n: T, bits: static[int]): MpUint[bits] {.noSideEffect.} =
+func initMpUint*[T: SomeInteger](n: T, bits: static[int]): MpUint[bits] {.inline.}=
   assert n >= 0.T
   when result.data is MpuintImpl:
     let len = n.bit_length
diff --git a/src/uint_public.nim b/src/uint_public.nim
index ba75329..e4aa440 100644
--- a/src/uint_public.nim
+++ b/src/uint_public.nim
@@ -15,14 +15,14 @@ type
   UInt256* = MpUint[256]
 
 template make_conv(conv_name: untyped, size: int): untyped =
-  proc `convname`*(n: SomeInteger): MpUint[size] {.noSideEffect, inline, noInit.}=
+  func `convname`*(n: SomeInteger): MpUint[size] {.inline, noInit.}=
     n.initMpUint(size)
 
 make_conv(u128, 128)
 make_conv(u256, 256)
 
 template make_unary(op, ResultTy): untyped =
-  proc `op`*(x: MpUint): ResultTy {.noInit, inline, noSideEffect.} =
+  func `op`*(x: MpUint): ResultTy {.noInit, inline.} =
     when resultTy is MpUint:
       result.data = op(x.data)
     else:
@@ -30,7 +30,7 @@ template make_unary(op, ResultTy): untyped =
   export op
 
 template make_binary(op, ResultTy): untyped =
-  proc `op`*(x, y: MpUint): ResultTy {.noInit, inline, noSideEffect.} =
+  func `op`*(x, y: MpUint): ResultTy {.noInit, inline.} =
     when ResultTy is MpUint:
       result.data = op(x.data, y.data)
     else:
@@ -38,31 +38,33 @@ template make_binary(op, ResultTy): untyped =
   export `op`
 
 template make_binary_inplace(op): untyped =
-  proc `op`*(x: var MpUint, y: MpUint) {.inline, noSideEffect.} =
+  func `op`*(x: var MpUint, y: MpUint) {.inline.} =
     op(x.data, y.data)
   export op
 
-import ./private/uint_binary_ops
+import ./private/uint_addsub
 
 make_binary(`+`, MpUint)
 make_binary_inplace(`+=`)
 make_binary(`-`, MpUint)
 make_binary_inplace(`-=`)
+
+import ./private/uint_mul
 make_binary(`*`, MpUint)
 
-import ./private/primitive_divmod,
-       ./private/uint_division
+import ./private/uint_div
 
 make_binary(`div`, MpUint)
 make_binary(`mod`, MpUint)
-proc divmod*(x, y: MpUint): tuple[quot, rem: MpUint] {.noInit, inline, noSideEffect.} =
+func divmod*(x, y: MpUint): tuple[quot, rem: MpUint] {.noInit, inline.} =
   (result.quot.data, result.rem.data) = divmod(x.data, y.data)
 
 import ./private/uint_comparison
 
 make_binary(`<`, bool)
 make_binary(`<=`, bool)
-proc isZero*(x: MpUint): bool {.inline, noSideEffect.} = isZero x
+make_binary(`==`, bool)
+func isZero*(x: MpUint): bool {.inline.} = isZero x.data
 
 import ./private/uint_bitwise_ops
diff --git a/tests/test_comparison.nim b/tests/test_comparison.nim
index d419ad5..7f17669 100644
--- a/tests/test_comparison.nim
+++ b/tests/test_comparison.nim
@@ -14,27 +14,42 @@ suite "Testing comparison operators":
     a = 10.initMpUint(16)
     b = 15.initMpUint(16)
     c = 150'u16
+    d = 4.initMpUint(128) shl 64
+    e = 4.initMpUint(128)
+    f = 4.initMpUint(128) shl 65
 
   test "< operator":
-    check: a < b
-    check: not (a + b < b)
-    check: not (a + a + a < b + b)
-    check: not (a * b < cast[MpUint[16]](c))
+    check:
+      a < b
+      not (a + b < b)
+      not (a + a + a < b + b)
+      not (a * b < cast[MpUint[16]](c))
+      e < d
+      d < f
 
   test "<= operator":
-    check: a <= b
-    check: not (a + b <= b)
-    check: a + a + a <= b + b
-    check: a * b <= cast[MpUint[16]](c)
+    check:
+      a <= b
+      not (a + b <= b)
+      a + a + a <= b + b
+      a * b <= cast[MpUint[16]](c)
+      e <= d
+      d <= f
 
   test "> operator":
-    check: b > a
-    check: not (b > a + b)
-    check: not (b + b > a + a + a)
-    check: not (cast[Mpuint[16]](c) > a * b)
+    check:
+      b > a
+      not (b > a + b)
+      not (b + b > a + a + a)
+      not (cast[Mpuint[16]](c) > a * b)
+      d > e
+      f > d
 
   test ">= operator":
-    check: b >= a
-    check: not (b >= a + b)
-    check: b + b >= a + a + a
-    check: cast[MpUint[16]](c) >= a * b
+    check:
+      b >= a
+      not (b >= a + b)
+      b + b >= a + a + a
+      cast[MpUint[16]](c) >= a * b
+      d >= e
+      f >= d
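
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): a minimal example of the public API
# touched by this diff, assuming mpint re-exports uint_public and uint_init as
# shown above. Hypothetical and untested against this exact tree.
#
#   import mpint
#
#   let a = 123456789.initMpUint(256)
#   let b = 987.initMpUint(256)
#   let (q, r) = divmod(a, b)   # dispatches to divmodBZ or divmodBS internally
#   assert q * b + r == a       # `==` is newly exported via make_binary(`==`, bool)
# ---------------------------------------------------------------------------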