From 6480939dcdac96462d814663514ec1315afc163e Mon Sep 17 00:00:00 2001 From: jangko Date: Tue, 20 Jun 2023 14:59:26 +0700 Subject: [PATCH] conversion between big integers --- stint/io.nim | 182 +++++++++++++++++++----------------- stint/private/datatypes.nim | 6 ++ tests/all_tests.nim | 4 +- tests/test_conversion.nim | 94 +++++++++---------- 4 files changed, 148 insertions(+), 138 deletions(-) diff --git a/stint/io.nim b/stint/io.nim index 43c62fe..59408af 100644 --- a/stint/io.nim +++ b/stint/io.nim @@ -16,6 +16,10 @@ import from stew/byteutils import toHex +# Helpers +# -------------------------------------------------------- +{.push raises: [], inline, gcsafe.} + template leastSignificantWord*(a: SomeBigInteger): Word = mixin limbs a.limbs[0] @@ -30,6 +34,18 @@ template signedWordType*(_: type SomeBigInteger): type = template wordType*(_: type SomeBigInteger): type = Word +template hash*(num: StUint|StInt): Hash = + # TODO: + # `hashData` is not particularly efficient. + # Explore better hashing solutions in nim-stew. + hashData(unsafeAddr num, sizeof num) + +{.pop.} + +# Constructors +# -------------------------------------------------------- +{.push raises: [], inline, gcsafe.} + func stuint*[T: SomeInteger](n: T, bits: static[int]): StUint[bits] {.inline.}= ## Converts an integer to an arbitrary precision integer. when sizeof(n) > sizeof(Word): @@ -59,6 +75,12 @@ func to*(a: SomeInteger, T: typedesc[StInt]): T = func to*(a: SomeUnsignedInt, T: typedesc[StUint]): T = stuint(a, result.bits) +{.pop.} + +# Conversions +# -------------------------------------------------------- +{.push raises: [], inline, gcsafe.} + func truncate*(num: StUint, T: typedesc[SomeInteger]): T {.inline.}= ## Extract the int, uint, int8-int64 or uint8-uint64 portion of a multi-precision integer. ## Note that int and uint are 32-bit on 32-bit platform. @@ -78,7 +100,7 @@ func truncate*(num: StInt, T: typedesc[SomeInteger]): T {.inline.}= ## For signed result type, result is undefined if input does not fit in the target type. let n = num.abs when sizeof(T) > sizeof(Word): - result = T(n.leastSignificantWord()) + result = T(n.leastSignificantWord()) else: result = T(n.leastSignificantWord() and Word(T.high)) @@ -104,87 +126,67 @@ func stuint*(a: StUint, bits: static[int]): StUint[bits] {.inline.} = ## unsigned int to unsigned int conversion ## smaller to bigger bits conversion will have the same value ## bigger to smaller bits conversion, the result is truncated - for i in 0 ..< result.len: - result[i] = a[i] + when bits <= a.bits: + for i in 0 ..< result.len: + result[i] = a[i] + else: + for i in 0 ..< a.len: + result[i] = a[i] -# func StUint*(a: StInt, bits: static[int]): StUint[bits] {.inline.} = -# ## signed int to unsigned int conversion -# ## current behavior is cast-like, copying bit pattern -# ## or truncating if input does not fit into destination -# const N = bitsof(x.data) -# when N < bits: -# when N <= 64: -# type T = StUint[N] -# result = StUint(convert[T](a).data, bits) -# else: -# smallToBig(result.data, a.data) -# elif N > bits: -# when bits <= 64: -# result = StUint(x.truncate(type(result.data)), bits) -# else: -# bigToSmall(result.data, a.data) -# else: -# result = convert[type(result)](a) +func stuint*(a: StInt, bits: static[int]): StUint[bits] {.inline.} = + ## signed int to unsigned int conversion + ## bigger to smaller bits conversion, the result is truncated + doAssert(a.isPositive, "Cannot convert negative number to unsigned int") + stuint(a.impl, bits) -# func stint*(a: StInt, bits: static[int]): StInt[bits] {.inline.} = -# ## signed int to signed int conversion -# ## will raise exception if input does not fit into destination -# const N = bitsof(a.data) -# when N < bits: -# when N <= 64: -# result = stint(a.data, bits) -# else: -# if a.isNegative: -# smallToBig(result.data, (-a).data) -# result = -result -# else: -# smallToBig(result.data, a.data) -# elif N > bits: -# template checkNegativeRange() = -# # due to bug #92, we skip negative range check -# when false: -# const dmin = stint((type result).low, N) -# if a < dmin: raise newException(RangeError, "value out of range") +func smallToBig(a: StInt, bits: static[int]): StInt[bits] = + if a.isNegative: + result.impl = stuint(a.neg.impl, bits) + result.negate + else: + result.impl = stuint(a.impl, bits) -# template checkPositiveRange() = -# const dmax = stint((type result).high, N) -# if a > dmax: raise newException(RangeError, "value out of range") +func stint*(a: StInt, bits: static[int]): StInt[bits] {.inline.} = + ## signed int to signed int conversion + ## will raise exception if input does not fit into destination + when a.bits < bits: + if a.isNegative: + result.impl = stuint(a.neg.impl, bits) + result.negate + else: + result.impl = stuint(a, bits) + elif a.bits > bits: + template checkNegativeRange() = + const dmin = smallToBig((type result).low, a.bits) + if a < dmin: raise newException(RangeDefect, "value out of range") -# when bits <= 64: -# if a.isNegative: -# checkNegativeRange() -# result = stint((-a).truncate(type(result.data)), bits) -# result = -result -# else: -# checkPositiveRange() -# result = stint(a.truncate(type(result.data)), bits) -# else: -# if a.isNegative: -# checkNegativeRange() -# bigToSmall(result.data, (-a).data) -# result = -result -# else: -# checkPositiveRange() -# bigToSmall(result.data, a.data) -# else: -# result = a + template checkPositiveRange() = + const dmax = smallToBig((type result).high, a.bits) + if a > dmax: raise newException(RangeDefect, "value out of range") -# func stint*(a: StUint, bits: static[int]): StInt[bits] {.inline.} = -# const N = bitsof(a.data) -# const dmax = StUint((type result).high, N) -# if a > dmax: raise newException(RangeError, "value out of range") -# when N < bits: -# when N <= 64: -# result = stint(a.data, bits) -# else: -# smallToBig(result.data, a.data) -# elif N > bits: -# when bits <= 64: -# result = stint(a.truncate(type(result.data)), bits) -# else: -# bigToSmall(result.data, a.data) -# else: -# result = convert[type(result)](a) + if a.isNegative: + checkNegativeRange() + result.impl = stuint(a.neg.impl, bits) + result.negate + else: + checkPositiveRange() + result.impl = stuint(a, bits) + else: + result = a + +func stint*(a: StUint, bits: static[int]): StInt[bits] {.inline.} = + ## signed int to unsigned int conversion + ## will raise exception if input does not fit into destination + + const dmax = stuint((type result).high, a.bits) + if a > dmax: raise newException(RangeDefect, "value out of range") + result.impl = stuint(a, bits) + +{.pop.} + +# Serializations to/from string +# -------------------------------------------------------- +{.push gcsafe.} func readHexChar(c: char): int8 {.inline.}= ## Converts an hex char to an int @@ -229,7 +231,9 @@ func readDecChar(c: range['0'..'9']): int {.inline.}= # specialization without branching for base <= 10. ord(c) - ord('0') -func parse*[bits: static[int]](input: string, T: typedesc[StUint[bits]], radix: static[uint8] = 10): T = +func parse*[bits: static[int]](input: string, + T: typedesc[StUint[bits]], + radix: static[uint8] = 10): T = ## Parse a string and store the result in a Stint[bits] or StUint[bits]. static: doAssert (radix >= 2) and radix <= 16, "Only base from 2..16 are supported" @@ -250,7 +254,9 @@ func parse*[bits: static[int]](input: string, T: typedesc[StUint[bits]], radix: result = result * base + input[curr].readHexChar.stuint(bits) nextNonBlank(curr, input) -func parse*[bits: static[int]](input: string, T: typedesc[StInt[bits]], radix: static[int8] = 10): T = +func parse*[bits: static[int]](input: string, + T: typedesc[StInt[bits]], + radix: static[int8] = 10): T = ## Parse a string and store the result in a Stint[bits] or StUint[bits]. static: doAssert (radix >= 2) and radix <= 16, "Only base from 2..16 are supported" @@ -285,11 +291,11 @@ func parse*[bits: static[int]](input: string, T: typedesc[StInt[bits]], radix: s if isNeg: result.negate -func fromHex*(T: typedesc[StUint|StInt], s: string): T {.inline.} = +func fromHex*(T: typedesc[StUint|StInt], s: string): T = ## Convert an hex string to the corresponding unsigned integer parse(s, type result, radix = 16) -func hexToUint*[bits: static[int]](hexString: string): StUint[bits] {.inline.} = +func hexToUint*[bits: static[int]](hexString: string): StUint[bits] = ## Convert an hex string to the corresponding unsigned integer parse(hexString, type result, radix = 16) @@ -354,6 +360,12 @@ func dumpHex*(a: StInt or StUint, order: static[Endianness] = bigEndian): string let bytes = a.toBytes(order) result = bytes.toHex() +{.pop.} + +# Serializations to/from bytes +# -------------------------------------------------------- +{.push raises: [], inline, noinit, gcsafe.} + export fromBytes, toBytes func readUintBE*[bits: static[int]](ba: openArray[byte]): StUint[bits] {.noinit, inline.}= @@ -372,12 +384,6 @@ func toByteArrayBE*[bits: static[int]](n: StUint[bits]): array[bits div 8, byte] ## - a big-endian array of the same size result = n.toBytesBE() -template hash*(num: StUint|StInt): Hash = - # TODO: - # `hashData` is not particularly efficient. - # Explore better hashing solutions in nim-stew. - hashData(unsafeAddr num, sizeof num) - func fromBytesBE*(T: type StUint, ba: openArray[byte], allowPadding: static[bool] = true): T {.noinit, inline.}= result = readUintBE[T.bits](ba) #when allowPadding: @@ -385,3 +391,5 @@ func fromBytesBE*(T: type StUint, ba: openArray[byte], allowPadding: static[bool template initFromBytesBE*(x: var StUint, ba: openArray[byte], allowPadding: static[bool] = true) = x = fromBytesBE(type x, ba, allowPadding) + +{.pop.} diff --git a/stint/private/datatypes.nim b/stint/private/datatypes.nim index ff20cfc..fcbaa80 100644 --- a/stint/private/datatypes.nim +++ b/stint/private/datatypes.nim @@ -72,6 +72,12 @@ template `[]`*(a: StUint, i: SomeInteger or BackwardsIndex): Word = template `[]=`*(a: var StUint, i: SomeInteger or BackwardsIndex, val: Word) = a.limbs[i] = val +template len*(a: StInt): int = + a.impl.limbs.len + +template len*(a: StUint): int = + a.limbs.len + # Bithacks # -------------------------------------------------------- diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 2869ae8..9bf1c12 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -33,5 +33,5 @@ import import - test_io#, - #test_conversion + test_io, + test_conversion diff --git a/tests/test_conversion.nim b/tests/test_conversion.nim index 723c243..69b8cc2 100644 --- a/tests/test_conversion.nim +++ b/tests/test_conversion.nim @@ -26,7 +26,7 @@ template chkStuintToStuint(chk: untyped, N, bits: static[int]) = chk $y == $yy chk $z == $zz -template chkStintToStuint(chk: untyped, N, bits: static[int]) = +template chkStintToStuint(chk, handleErr: untyped, N, bits: static[int]) = block: let w = StInt[N].low let x = StInt[N].high @@ -34,20 +34,20 @@ template chkStintToStuint(chk: untyped, N, bits: static[int]) = let z = stint(1, N) let v = stint(-1, N) - let ww = stuint(w, bits) + handleErr AssertionDefect: + discard stuint(w, bits) + let xx = stuint(x, bits) let yy = stuint(y, bits) let zz = stuint(z, bits) - let vv = stuint(v, bits) + + handleErr AssertionDefect: + discard stuint(v, bits) when N <= bits: chk $x == $xx - chk w.toHex == ww.toHex - chk v.toHex == vv.toHex else: - chk ww == stuint(0, bits) chk $xx == $(StUint[bits].high) - chk $vv == $(StUint[bits].high) chk $y == $yy chk $z == $zz @@ -76,7 +76,7 @@ template chkStintToStint(chk: untyped, N, bits: static[int]) = chk $z == $zz chk $v == $vv -template chkStuintToStint(chk: untyped, N, bits: static[int]) = +template chkStuintToStint(chk, handleErr: untyped, N, bits: static[int]) = block: let y = stuint(0, N) let z = stuint(1, N) @@ -86,12 +86,8 @@ template chkStuintToStint(chk: untyped, N, bits: static[int]) = let zz = stint(z, bits) when bits <= N: - when nimvm: - # expect(...) cannot run in Nim VM - discard - else: - expect(ValueError): - discard stint(v, bits) + handleErr RangeDefect: + discard stint(v, bits) else: let vv = stint(v, bits) chk v.toHex == vv.toHex @@ -99,7 +95,7 @@ template chkStuintToStint(chk: untyped, N, bits: static[int]) = chk $y == $yy chk $z == $zz -template testConversion(chk, tst: untyped) = +template testConversion(chk, tst, handleErr: untyped) = tst "stuint to stuint": chkStuintToStuint(chk, 64, 64) chkStuintToStuint(chk, 64, 128) @@ -122,25 +118,25 @@ template testConversion(chk, tst: untyped) = chkStuintToStuint(chk, 512, 512) tst "stint to stuint": - chkStintToStuint(chk, 64, 64) - chkStintToStuint(chk, 64, 128) - chkStintToStuint(chk, 64, 256) - chkStintToStuint(chk, 64, 512) + chkStintToStuint(chk, handleErr, 64, 64) + chkStintToStuint(chk, handleErr, 64, 128) + chkStintToStuint(chk, handleErr, 64, 256) + chkStintToStuint(chk, handleErr, 64, 512) - chkStintToStuint(chk, 128, 64) - chkStintToStuint(chk, 128, 128) - chkStintToStuint(chk, 128, 256) - chkStintToStuint(chk, 128, 512) + chkStintToStuint(chk, handleErr, 128, 64) + chkStintToStuint(chk, handleErr, 128, 128) + chkStintToStuint(chk, handleErr, 128, 256) + chkStintToStuint(chk, handleErr, 128, 512) - chkStintToStuint(chk, 256, 64) - chkStintToStuint(chk, 256, 128) - chkStintToStuint(chk, 256, 256) - chkStintToStuint(chk, 256, 512) + chkStintToStuint(chk, handleErr, 256, 64) + chkStintToStuint(chk, handleErr, 256, 128) + chkStintToStuint(chk, handleErr, 256, 256) + chkStintToStuint(chk, handleErr, 256, 512) - chkStintToStuint(chk, 512, 64) - chkStintToStuint(chk, 512, 128) - chkStintToStuint(chk, 512, 256) - chkStintToStuint(chk, 512, 512) + chkStintToStuint(chk, handleErr, 512, 64) + chkStintToStuint(chk, handleErr, 512, 128) + chkStintToStuint(chk, handleErr, 512, 256) + chkStintToStuint(chk, handleErr, 512, 512) tst "stint to stint": chkStintToStint(chk, 64, 64) @@ -164,33 +160,33 @@ template testConversion(chk, tst: untyped) = chkStintToStint(chk, 512, 512) tst "stuint to stint": - chkStuintToStint(chk, 64, 64) - chkStuintToStint(chk, 64, 128) - chkStuintToStint(chk, 64, 256) - chkStuintToStint(chk, 64, 512) + chkStuintToStint(chk, handleErr, 64, 64) + chkStuintToStint(chk, handleErr, 64, 128) + chkStuintToStint(chk, handleErr, 64, 256) + chkStuintToStint(chk, handleErr, 64, 512) - chkStuintToStint(chk, 128, 64) - chkStuintToStint(chk, 128, 128) - chkStuintToStint(chk, 128, 256) - chkStuintToStint(chk, 128, 512) + chkStuintToStint(chk, handleErr, 128, 64) + chkStuintToStint(chk, handleErr, 128, 128) + chkStuintToStint(chk, handleErr, 128, 256) + chkStuintToStint(chk, handleErr, 128, 512) - chkStuintToStint(chk, 256, 64) - chkStuintToStint(chk, 256, 128) - chkStuintToStint(chk, 256, 256) - chkStuintToStint(chk, 256, 512) + chkStuintToStint(chk, handleErr, 256, 64) + chkStuintToStint(chk, handleErr, 256, 128) + chkStuintToStint(chk, handleErr, 256, 256) + chkStuintToStint(chk, handleErr, 256, 512) - chkStuintToStint(chk, 512, 64) - chkStuintToStint(chk, 512, 128) - chkStuintToStint(chk, 512, 256) - chkStuintToStint(chk, 512, 512) + chkStuintToStint(chk, handleErr, 512, 64) + chkStuintToStint(chk, handleErr, 512, 128) + chkStuintToStint(chk, handleErr, 512, 256) + chkStuintToStint(chk, handleErr, 512, 512) static: - testConversion(ctCheck, ctTest) + testConversion(ctCheck, ctTest, ctExpect) proc main() = # Nim GC protests we are using too much global variables # so put it in a proc suite "Testing conversion between big integers": - testConversion(check, test) + testConversion(check, test, expect) main()