Implement exponentiation, test mul, split mul/div tests

This commit is contained in:
Mamy André-Ratsimbazafy 2020-09-06 16:27:11 +02:00 committed by jangko
parent 254d4da649
commit dc9e0a43ca
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
5 changed files with 153 additions and 137 deletions

View File

@ -1,50 +0,0 @@
# Stint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
./datatypes,
./uint_bitwise_ops, ./uint_mul, ./initialization, ./uint_comparison
func pow*(x: UintImpl, y: Natural): UintImpl =
## Compute ``x`` to the power of ``y``,
## ``x`` must be non-negative
# Implementation uses exponentiation by squaring
# See Nim math module: https://github.com/nim-lang/Nim/blob/4ed24aa3eb78ba4ff55aac3008ec3c2427776e50/lib/pure/math.nim#L429
# And Eli Bendersky's blog: https://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms
var (x, y) = (x, y)
result = one(type x)
while true:
if bool(y and 1): # if y is odd
result = result * x
y = y shr 1
if y == 0:
break
x = x * x
func pow*(x: UintImpl, y: UintImpl): UintImpl =
## Compute ``x`` to the power of ``y``,
## ``x`` must be non-negative
# Implementation uses exponentiation by squaring
# See Nim math module: https://github.com/nim-lang/Nim/blob/4ed24aa3eb78ba4ff55aac3008ec3c2427776e50/lib/pure/math.nim#L429
# And Eli Bendersky's blog: https://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms
var (x, y) = (x, y)
result = one(type x)
while true:
if y.isOdd:
result = result * x
y = y shr 1
if y.isZero:
break
x = x * x

View File

@ -23,18 +23,22 @@ export StUint
func setZero*(a: var StUint) = func setZero*(a: var StUint) =
## Set ``a`` to 0 ## Set ``a`` to 0
zeroMem(a[0].addr, sizeof(a)) for i in 0 ..< a.limbs.len:
a[i] = 0
func setSmallInt(a: var StUint, k: Word) =
## Set ``a`` to k
when cpuEndian == littleEndian:
a.limbs[0] = k
for i in 1 ..< a.limbs.len:
a.limbs[i] = 0
else:
a.limbs[^1] = k
for i in 0 ..< a.limb.len - 1:
a.limbs[i] = 0
func setOne*(a: var StUint) = func setOne*(a: var StUint) =
## Set ``a`` to 1 setSmallInt(a, 1)
when cpuEndian == littleEndian:
a.limbs[0] = 1
when a.limbs.len > 1:
zeroMem(a.limbs[1].addr, (a.limbs.len - 1) * sizeof(SecretWord))
else:
a.limbs[^1] = 1
when a.limbs.len > 1:
zeroMem(a.limbs[0].addr, (a.len - 1) * sizeof(SecretWord))
func zero*[bits: static[int]](T: typedesc[Stuint[bits]]): T {.inline.} = func zero*[bits: static[int]](T: typedesc[Stuint[bits]]): T {.inline.} =
## Returns the zero of the input type ## Returns the zero of the input type
@ -42,7 +46,7 @@ func zero*[bits: static[int]](T: typedesc[Stuint[bits]]): T {.inline.} =
func one*[bits: static[int]](T: typedesc[Stuint[bits]]): T {.inline.} = func one*[bits: static[int]](T: typedesc[Stuint[bits]]): T {.inline.} =
## Returns the one of the input type ## Returns the one of the input type
result.limbs.setOne() result.setOne()
func high*[bits](_: typedesc[Stuint[bits]]): Stuint[bits] {.inline.} = func high*[bits](_: typedesc[Stuint[bits]]): Stuint[bits] {.inline.} =
for wr in leastToMostSig(result): for wr in leastToMostSig(result):
@ -279,8 +283,52 @@ func `*`*(a, b: Stuint): Stuint =
result.clearExtraBits() result.clearExtraBits()
{.pop.} {.pop.}
# Division & Modulo
# --------------------------------------------------------
# Exponentiation # Exponentiation
# -------------------------------------------------------- # --------------------------------------------------------
{.push raises: [], noInit, gcsafe.}
func pow*(a: Stuint, e: Natural): Stuint =
## Compute ``a`` to the power of ``e``,
## ``e`` must be non-negative
# Implementation uses exponentiation by squaring
# See Nim math module: https://github.com/nim-lang/Nim/blob/4ed24aa3eb78ba4ff55aac3008ec3c2427776e50/lib/pure/math.nim#L429
# And Eli Bendersky's blog: https://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms
var (a, e) = (a, e)
result.setOne()
while true:
if bool(e and 1): # if y is odd
result = result * a
e = e shr 1
if e == 0:
break
a = a * a
func pow*[aBits, eBits](a: Stuint[aBits], e: Stuint[eBits]): Stuint[aBits] =
## Compute ``x`` to the power of ``y``,
## ``x`` must be non-negative
# Implementation uses exponentiation by squaring
# See Nim math module: https://github.com/nim-lang/Nim/blob/4ed24aa3eb78ba4ff55aac3008ec3c2427776e50/lib/pure/math.nim#L429
# And Eli Bendersky's blog: https://eli.thegreenplace.net/2009/03/21/efficient-integer-exponentiation-algorithms
var (a, e) = (a, e)
result.setOne()
while true:
if e.isOdd:
result = result * a
e = e shr 1
if e.isZero:
break
a = a * a
{.pop.}
# Division & Modulo
# --------------------------------------------------------

View File

@ -9,9 +9,6 @@
import ../stint, unittest, test_helpers import ../stint, unittest, test_helpers
template chkMul(chk: untyped, a, b, c: string, bits: int) =
chk (fromHex(StUint[bits], a) * fromHex(StUint[bits], b)) == fromHex(StUint[bits], c)
template chkDiv(chk: untyped, a, b, c: string, bits: int) = template chkDiv(chk: untyped, a, b, c: string, bits: int) =
chk (fromHex(StUint[bits], a) div fromHex(StUint[bits], b)) == fromHex(StUint[bits], c) chk (fromHex(StUint[bits], a) div fromHex(StUint[bits], b)) == fromHex(StUint[bits], c)
@ -21,41 +18,7 @@ template chkMod(chk: untyped, a, b, c: string, bits: int) =
template chkDivMod(chk: untyped, a, b, c, d: string, bits: int) = template chkDivMod(chk: untyped, a, b, c, d: string, bits: int) =
chk divmod(fromHex(StUint[bits], a), fromHex(StUint[bits], b)) == (fromHex(StUint[bits], c), fromHex(StUint[bits], d)) chk divmod(fromHex(StUint[bits], a), fromHex(StUint[bits], b)) == (fromHex(StUint[bits], c), fromHex(StUint[bits], d))
template testMuldiv(chk, tst: untyped) = template testdivmod(chk, tst: untyped) =
tst "operator `mul`":
chkMul(chk, "0", "3", "0", 8)
chkMul(chk, "1", "3", "3", 8)
chkMul(chk, "64", "3", "2C", 8) # overflow
chkMul(chk, "0", "3", "0", 16)
chkMul(chk, "1", "3", "3", 16)
chkMul(chk, "64", "3", "12C", 16)
chkMul(chk, "1770", "46", "68A0", 16) # overflow
chkMul(chk, "0", "3", "0", 32)
chkMul(chk, "1", "3", "3", 32)
chkMul(chk, "64", "3", "12C", 32)
chkMul(chk, "1770", "46", "668A0", 32)
chkMul(chk, "13880", "13880", "7D784000", 32) # overflow
chkMul(chk, "0", "3", "0", 64)
chkMul(chk, "1", "3", "3", 64)
chkMul(chk, "64", "3", "12C", 64)
chkMul(chk, "1770", "46", "668A0", 64)
chkMul(chk, "13880", "13880", "17D784000", 64)
chkMul(chk, "3B9ACA00", "E8D4A51000", "35C9ADC5DEA00000", 64) # overflow
chkMul(chk, "0", "3", "0", 128)
chkMul(chk, "1", "3", "3", 128)
chkMul(chk, "64", "3", "12C", 128)
chkMul(chk, "1770", "46", "668A0", 128)
chkMul(chk, "13880", "13880", "17D784000", 128)
chkMul(chk, "3B9ACA00", "E8D4A51000", "3635C9ADC5DEA00000", 128)
chkMul(chk, "25295F0D1", "10", "25295F0D10", 128)
chkMul(chk, "123456789ABCDEF00", "123456789ABCDEF00", "4b66dc33f6acdca5e20890f2a5210000", 128) # overflow
chkMul(chk, "123456789ABCDEF00", "123456789ABCDEF00", "14b66dc33f6acdca5e20890f2a5210000", 256)
tst "operator `div`": tst "operator `div`":
chkDiv(chk, "0", "3", "0", 8) chkDiv(chk, "0", "3", "0", 8)
chkDiv(chk, "1", "3", "0", 8) chkDiv(chk, "1", "3", "0", 8)
@ -212,44 +175,10 @@ template testMuldiv(chk, tst: untyped) =
chkDivMod(chk, "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", "27", "6906906906906906906906906906906", "15", 128) chkDivMod(chk, "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", "27", "6906906906906906906906906906906", "15", 128)
static: static:
testMuldiv(ctCheck, ctTest) testdivmod(ctCheck, ctTest)
suite "Wider unsigned int muldiv coverage": suite "Wider unsigned int muldiv coverage":
testMuldiv(check, test) testdivmod(check, test)
suite "Testing unsigned int multiplication implementation":
test "Multiplication with result fitting in low half":
let a = 10000.stuint(64)
let b = 10000.stuint(64)
check: cast[uint64](a*b) == 100_000_000'u64 # need 27-bits
test "Multiplication with result overflowing low half":
let a = 1_000_000.stuint(64)
let b = 1_000_000.stuint(64)
check: cast[uint64](a*b) == 1_000_000_000_000'u64 # need 40 bits
test "Full overflow is handled like native unsigned types":
let a = 1_000_000_000.stuint(64)
let b = 1_000_000_000.stuint(64)
let c = 1_000.stuint(64)
let x = 1_000_000_000'u64
let y = 1_000_000_000'u64
let z = 1_000'u64
let w = x*y*z
#check: cast[uint64](a*b*c) == 1_000_000_000_000_000_000_000'u64 # need 70-bits
check: cast[uint64](a*b*c) == w
test "Nim v1.0.2 32 bit type inference rule changed":
let x = 9975492817.stuint(256)
let y = 16.stuint(256)
check x * y == 159607885072.stuint(256)
suite "Testing unsigned int division and modulo implementation": suite "Testing unsigned int division and modulo implementation":
test "Divmod(100, 13) returns the correct result": test "Divmod(100, 13) returns the correct result":

View File

@ -9,6 +9,7 @@
import ../stint, unittest, stew/byteutils, test_helpers import ../stint, unittest, stew/byteutils, test_helpers
template chkSwapBytes(chk: untyped, bits: int, hex: string) = template chkSwapBytes(chk: untyped, bits: int, hex: string) =
# dumpHex already do the job to swap the output if # dumpHex already do the job to swap the output if
# we use `littleEndian` on both platform # we use `littleEndian` on both platform

88
tests/test_uint_mul.nim Normal file
View File

@ -0,0 +1,88 @@
# Stint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ../stint, unittest, test_helpers
template chkMul(chk: untyped, a, b, c: string, bits: int) =
chk (fromHex(Stuint[bits], a) * fromHex(Stuint[bits], b)) == fromHex(Stuint[bits], c)
template testMul(chk, tst: untyped) =
tst "operator `mul`":
chkMul(chk, "0", "3", "0", 8)
chkMul(chk, "1", "3", "3", 8)
chkMul(chk, "64", "3", "2C", 8) # overflow
chkMul(chk, "0", "3", "0", 16)
chkMul(chk, "1", "3", "3", 16)
chkMul(chk, "64", "3", "12C", 16)
chkMul(chk, "1770", "46", "68A0", 16) # overflow
chkMul(chk, "0", "3", "0", 32)
chkMul(chk, "1", "3", "3", 32)
chkMul(chk, "64", "3", "12C", 32)
chkMul(chk, "1770", "46", "668A0", 32)
chkMul(chk, "13880", "13880", "7D784000", 32) # overflow
chkMul(chk, "0", "3", "0", 64)
chkMul(chk, "1", "3", "3", 64)
chkMul(chk, "64", "3", "12C", 64)
chkMul(chk, "1770", "46", "668A0", 64)
chkMul(chk, "13880", "13880", "17D784000", 64)
chkMul(chk, "3B9ACA00", "E8D4A51000", "35C9ADC5DEA00000", 64) # overflow
chkMul(chk, "0", "3", "0", 128)
chkMul(chk, "1", "3", "3", 128)
chkMul(chk, "64", "3", "12C", 128)
chkMul(chk, "1770", "46", "668A0", 128)
chkMul(chk, "13880", "13880", "17D784000", 128)
chkMul(chk, "3B9ACA00", "E8D4A51000", "3635C9ADC5DEA00000", 128)
chkMul(chk, "25295F0D1", "10", "25295F0D10", 128)
chkMul(chk, "123456789ABCDEF00", "123456789ABCDEF00", "4b66dc33f6acdca5e20890f2a5210000", 128) # overflow
chkMul(chk, "123456789ABCDEF00", "123456789ABCDEF00", "14b66dc33f6acdca5e20890f2a5210000", 256)
static:
testMul(ctCheck, ctTest)
suite "Wider unsigned int muldiv coverage":
testMul(check, test)
suite "Testing unsigned int multiplication implementation":
test "Multiplication with result fitting in low half":
let a = 10000.stuint(64)
let b = 10000.stuint(64)
check: cast[uint64](a*b) == 100_000_000'u64 # need 27-bits
test "Multiplication with result overflowing low half":
let a = 1_000_000.stuint(64)
let b = 1_000_000.stuint(64)
check: cast[uint64](a*b) == 1_000_000_000_000'u64 # need 40 bits
test "Full overflow is handled like native unsigned types":
let a = 1_000_000_000.stuint(64)
let b = 1_000_000_000.stuint(64)
let c = 1_000.stuint(64)
let x = 1_000_000_000'u64
let y = 1_000_000_000'u64
let z = 1_000'u64
let w = x*y*z
#check: cast[uint64](a*b*c) == 1_000_000_000_000_000_000_000'u64 # need 70-bits
check: cast[uint64](a*b*c) == w
test "Nim v1.0.2 32 bit type inference rule changed":
let x = 9975492817.stuint(256)
let y = 16.stuint(256)
check x * y == 159607885072.stuint(256)