conversion between big integers

This commit is contained in:
jangko 2023-06-20 14:59:26 +07:00
parent 4a3f300bd6
commit 6480939dcd
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
4 changed files with 148 additions and 138 deletions

View File

@ -16,6 +16,10 @@ import
from stew/byteutils import toHex
# Helpers
# --------------------------------------------------------
{.push raises: [], inline, gcsafe.}
template leastSignificantWord*(a: SomeBigInteger): Word =
mixin limbs
a.limbs[0]
@ -30,6 +34,18 @@ template signedWordType*(_: type SomeBigInteger): type =
template wordType*(_: type SomeBigInteger): type =
Word
template hash*(num: StUint|StInt): Hash =
# TODO:
# `hashData` is not particularly efficient.
# Explore better hashing solutions in nim-stew.
hashData(unsafeAddr num, sizeof num)
{.pop.}
# Constructors
# --------------------------------------------------------
{.push raises: [], inline, gcsafe.}
func stuint*[T: SomeInteger](n: T, bits: static[int]): StUint[bits] {.inline.}=
## Converts an integer to an arbitrary precision integer.
when sizeof(n) > sizeof(Word):
@ -59,6 +75,12 @@ func to*(a: SomeInteger, T: typedesc[StInt]): T =
func to*(a: SomeUnsignedInt, T: typedesc[StUint]): T =
stuint(a, result.bits)
{.pop.}
# Conversions
# --------------------------------------------------------
{.push raises: [], inline, gcsafe.}
func truncate*(num: StUint, T: typedesc[SomeInteger]): T {.inline.}=
## Extract the int, uint, int8-int64 or uint8-uint64 portion of a multi-precision integer.
## Note that int and uint are 32-bit on 32-bit platform.
@ -78,7 +100,7 @@ func truncate*(num: StInt, T: typedesc[SomeInteger]): T {.inline.}=
## For signed result type, result is undefined if input does not fit in the target type.
let n = num.abs
when sizeof(T) > sizeof(Word):
result = T(n.leastSignificantWord())
result = T(n.leastSignificantWord())
else:
result = T(n.leastSignificantWord() and Word(T.high))
@ -104,87 +126,67 @@ func stuint*(a: StUint, bits: static[int]): StUint[bits] {.inline.} =
## unsigned int to unsigned int conversion
## smaller to bigger bits conversion will have the same value
## bigger to smaller bits conversion, the result is truncated
for i in 0 ..< result.len:
result[i] = a[i]
when bits <= a.bits:
for i in 0 ..< result.len:
result[i] = a[i]
else:
for i in 0 ..< a.len:
result[i] = a[i]
# func StUint*(a: StInt, bits: static[int]): StUint[bits] {.inline.} =
# ## signed int to unsigned int conversion
# ## current behavior is cast-like, copying bit pattern
# ## or truncating if input does not fit into destination
# const N = bitsof(x.data)
# when N < bits:
# when N <= 64:
# type T = StUint[N]
# result = StUint(convert[T](a).data, bits)
# else:
# smallToBig(result.data, a.data)
# elif N > bits:
# when bits <= 64:
# result = StUint(x.truncate(type(result.data)), bits)
# else:
# bigToSmall(result.data, a.data)
# else:
# result = convert[type(result)](a)
func stuint*(a: StInt, bits: static[int]): StUint[bits] {.inline.} =
## signed int to unsigned int conversion
## bigger to smaller bits conversion, the result is truncated
doAssert(a.isPositive, "Cannot convert negative number to unsigned int")
stuint(a.impl, bits)
# func stint*(a: StInt, bits: static[int]): StInt[bits] {.inline.} =
# ## signed int to signed int conversion
# ## will raise exception if input does not fit into destination
# const N = bitsof(a.data)
# when N < bits:
# when N <= 64:
# result = stint(a.data, bits)
# else:
# if a.isNegative:
# smallToBig(result.data, (-a).data)
# result = -result
# else:
# smallToBig(result.data, a.data)
# elif N > bits:
# template checkNegativeRange() =
# # due to bug #92, we skip negative range check
# when false:
# const dmin = stint((type result).low, N)
# if a < dmin: raise newException(RangeError, "value out of range")
func smallToBig(a: StInt, bits: static[int]): StInt[bits] =
if a.isNegative:
result.impl = stuint(a.neg.impl, bits)
result.negate
else:
result.impl = stuint(a.impl, bits)
# template checkPositiveRange() =
# const dmax = stint((type result).high, N)
# if a > dmax: raise newException(RangeError, "value out of range")
func stint*(a: StInt, bits: static[int]): StInt[bits] {.inline.} =
## signed int to signed int conversion
## will raise exception if input does not fit into destination
when a.bits < bits:
if a.isNegative:
result.impl = stuint(a.neg.impl, bits)
result.negate
else:
result.impl = stuint(a, bits)
elif a.bits > bits:
template checkNegativeRange() =
const dmin = smallToBig((type result).low, a.bits)
if a < dmin: raise newException(RangeDefect, "value out of range")
# when bits <= 64:
# if a.isNegative:
# checkNegativeRange()
# result = stint((-a).truncate(type(result.data)), bits)
# result = -result
# else:
# checkPositiveRange()
# result = stint(a.truncate(type(result.data)), bits)
# else:
# if a.isNegative:
# checkNegativeRange()
# bigToSmall(result.data, (-a).data)
# result = -result
# else:
# checkPositiveRange()
# bigToSmall(result.data, a.data)
# else:
# result = a
template checkPositiveRange() =
const dmax = smallToBig((type result).high, a.bits)
if a > dmax: raise newException(RangeDefect, "value out of range")
# func stint*(a: StUint, bits: static[int]): StInt[bits] {.inline.} =
# const N = bitsof(a.data)
# const dmax = StUint((type result).high, N)
# if a > dmax: raise newException(RangeError, "value out of range")
# when N < bits:
# when N <= 64:
# result = stint(a.data, bits)
# else:
# smallToBig(result.data, a.data)
# elif N > bits:
# when bits <= 64:
# result = stint(a.truncate(type(result.data)), bits)
# else:
# bigToSmall(result.data, a.data)
# else:
# result = convert[type(result)](a)
if a.isNegative:
checkNegativeRange()
result.impl = stuint(a.neg.impl, bits)
result.negate
else:
checkPositiveRange()
result.impl = stuint(a, bits)
else:
result = a
func stint*(a: StUint, bits: static[int]): StInt[bits] {.inline.} =
## signed int to unsigned int conversion
## will raise exception if input does not fit into destination
const dmax = stuint((type result).high, a.bits)
if a > dmax: raise newException(RangeDefect, "value out of range")
result.impl = stuint(a, bits)
{.pop.}
# Serializations to/from string
# --------------------------------------------------------
{.push gcsafe.}
func readHexChar(c: char): int8 {.inline.}=
## Converts an hex char to an int
@ -229,7 +231,9 @@ func readDecChar(c: range['0'..'9']): int {.inline.}=
# specialization without branching for base <= 10.
ord(c) - ord('0')
func parse*[bits: static[int]](input: string, T: typedesc[StUint[bits]], radix: static[uint8] = 10): T =
func parse*[bits: static[int]](input: string,
T: typedesc[StUint[bits]],
radix: static[uint8] = 10): T =
## Parse a string and store the result in a Stint[bits] or StUint[bits].
static: doAssert (radix >= 2) and radix <= 16, "Only base from 2..16 are supported"
@ -250,7 +254,9 @@ func parse*[bits: static[int]](input: string, T: typedesc[StUint[bits]], radix:
result = result * base + input[curr].readHexChar.stuint(bits)
nextNonBlank(curr, input)
func parse*[bits: static[int]](input: string, T: typedesc[StInt[bits]], radix: static[int8] = 10): T =
func parse*[bits: static[int]](input: string,
T: typedesc[StInt[bits]],
radix: static[int8] = 10): T =
## Parse a string and store the result in a Stint[bits] or StUint[bits].
static: doAssert (radix >= 2) and radix <= 16, "Only base from 2..16 are supported"
@ -285,11 +291,11 @@ func parse*[bits: static[int]](input: string, T: typedesc[StInt[bits]], radix: s
if isNeg:
result.negate
func fromHex*(T: typedesc[StUint|StInt], s: string): T {.inline.} =
func fromHex*(T: typedesc[StUint|StInt], s: string): T =
## Convert an hex string to the corresponding unsigned integer
parse(s, type result, radix = 16)
func hexToUint*[bits: static[int]](hexString: string): StUint[bits] {.inline.} =
func hexToUint*[bits: static[int]](hexString: string): StUint[bits] =
## Convert an hex string to the corresponding unsigned integer
parse(hexString, type result, radix = 16)
@ -354,6 +360,12 @@ func dumpHex*(a: StInt or StUint, order: static[Endianness] = bigEndian): string
let bytes = a.toBytes(order)
result = bytes.toHex()
{.pop.}
# Serializations to/from bytes
# --------------------------------------------------------
{.push raises: [], inline, noinit, gcsafe.}
export fromBytes, toBytes
func readUintBE*[bits: static[int]](ba: openArray[byte]): StUint[bits] {.noinit, inline.}=
@ -372,12 +384,6 @@ func toByteArrayBE*[bits: static[int]](n: StUint[bits]): array[bits div 8, byte]
## - a big-endian array of the same size
result = n.toBytesBE()
template hash*(num: StUint|StInt): Hash =
# TODO:
# `hashData` is not particularly efficient.
# Explore better hashing solutions in nim-stew.
hashData(unsafeAddr num, sizeof num)
func fromBytesBE*(T: type StUint, ba: openArray[byte], allowPadding: static[bool] = true): T {.noinit, inline.}=
result = readUintBE[T.bits](ba)
#when allowPadding:
@ -385,3 +391,5 @@ func fromBytesBE*(T: type StUint, ba: openArray[byte], allowPadding: static[bool
template initFromBytesBE*(x: var StUint, ba: openArray[byte], allowPadding: static[bool] = true) =
x = fromBytesBE(type x, ba, allowPadding)
{.pop.}

View File

@ -72,6 +72,12 @@ template `[]`*(a: StUint, i: SomeInteger or BackwardsIndex): Word =
template `[]=`*(a: var StUint, i: SomeInteger or BackwardsIndex, val: Word) =
a.limbs[i] = val
template len*(a: StInt): int =
a.impl.limbs.len
template len*(a: StUint): int =
a.limbs.len
# Bithacks
# --------------------------------------------------------

View File

@ -33,5 +33,5 @@ import
import
test_io#,
#test_conversion
test_io,
test_conversion

View File

@ -26,7 +26,7 @@ template chkStuintToStuint(chk: untyped, N, bits: static[int]) =
chk $y == $yy
chk $z == $zz
template chkStintToStuint(chk: untyped, N, bits: static[int]) =
template chkStintToStuint(chk, handleErr: untyped, N, bits: static[int]) =
block:
let w = StInt[N].low
let x = StInt[N].high
@ -34,20 +34,20 @@ template chkStintToStuint(chk: untyped, N, bits: static[int]) =
let z = stint(1, N)
let v = stint(-1, N)
let ww = stuint(w, bits)
handleErr AssertionDefect:
discard stuint(w, bits)
let xx = stuint(x, bits)
let yy = stuint(y, bits)
let zz = stuint(z, bits)
let vv = stuint(v, bits)
handleErr AssertionDefect:
discard stuint(v, bits)
when N <= bits:
chk $x == $xx
chk w.toHex == ww.toHex
chk v.toHex == vv.toHex
else:
chk ww == stuint(0, bits)
chk $xx == $(StUint[bits].high)
chk $vv == $(StUint[bits].high)
chk $y == $yy
chk $z == $zz
@ -76,7 +76,7 @@ template chkStintToStint(chk: untyped, N, bits: static[int]) =
chk $z == $zz
chk $v == $vv
template chkStuintToStint(chk: untyped, N, bits: static[int]) =
template chkStuintToStint(chk, handleErr: untyped, N, bits: static[int]) =
block:
let y = stuint(0, N)
let z = stuint(1, N)
@ -86,12 +86,8 @@ template chkStuintToStint(chk: untyped, N, bits: static[int]) =
let zz = stint(z, bits)
when bits <= N:
when nimvm:
# expect(...) cannot run in Nim VM
discard
else:
expect(ValueError):
discard stint(v, bits)
handleErr RangeDefect:
discard stint(v, bits)
else:
let vv = stint(v, bits)
chk v.toHex == vv.toHex
@ -99,7 +95,7 @@ template chkStuintToStint(chk: untyped, N, bits: static[int]) =
chk $y == $yy
chk $z == $zz
template testConversion(chk, tst: untyped) =
template testConversion(chk, tst, handleErr: untyped) =
tst "stuint to stuint":
chkStuintToStuint(chk, 64, 64)
chkStuintToStuint(chk, 64, 128)
@ -122,25 +118,25 @@ template testConversion(chk, tst: untyped) =
chkStuintToStuint(chk, 512, 512)
tst "stint to stuint":
chkStintToStuint(chk, 64, 64)
chkStintToStuint(chk, 64, 128)
chkStintToStuint(chk, 64, 256)
chkStintToStuint(chk, 64, 512)
chkStintToStuint(chk, handleErr, 64, 64)
chkStintToStuint(chk, handleErr, 64, 128)
chkStintToStuint(chk, handleErr, 64, 256)
chkStintToStuint(chk, handleErr, 64, 512)
chkStintToStuint(chk, 128, 64)
chkStintToStuint(chk, 128, 128)
chkStintToStuint(chk, 128, 256)
chkStintToStuint(chk, 128, 512)
chkStintToStuint(chk, handleErr, 128, 64)
chkStintToStuint(chk, handleErr, 128, 128)
chkStintToStuint(chk, handleErr, 128, 256)
chkStintToStuint(chk, handleErr, 128, 512)
chkStintToStuint(chk, 256, 64)
chkStintToStuint(chk, 256, 128)
chkStintToStuint(chk, 256, 256)
chkStintToStuint(chk, 256, 512)
chkStintToStuint(chk, handleErr, 256, 64)
chkStintToStuint(chk, handleErr, 256, 128)
chkStintToStuint(chk, handleErr, 256, 256)
chkStintToStuint(chk, handleErr, 256, 512)
chkStintToStuint(chk, 512, 64)
chkStintToStuint(chk, 512, 128)
chkStintToStuint(chk, 512, 256)
chkStintToStuint(chk, 512, 512)
chkStintToStuint(chk, handleErr, 512, 64)
chkStintToStuint(chk, handleErr, 512, 128)
chkStintToStuint(chk, handleErr, 512, 256)
chkStintToStuint(chk, handleErr, 512, 512)
tst "stint to stint":
chkStintToStint(chk, 64, 64)
@ -164,33 +160,33 @@ template testConversion(chk, tst: untyped) =
chkStintToStint(chk, 512, 512)
tst "stuint to stint":
chkStuintToStint(chk, 64, 64)
chkStuintToStint(chk, 64, 128)
chkStuintToStint(chk, 64, 256)
chkStuintToStint(chk, 64, 512)
chkStuintToStint(chk, handleErr, 64, 64)
chkStuintToStint(chk, handleErr, 64, 128)
chkStuintToStint(chk, handleErr, 64, 256)
chkStuintToStint(chk, handleErr, 64, 512)
chkStuintToStint(chk, 128, 64)
chkStuintToStint(chk, 128, 128)
chkStuintToStint(chk, 128, 256)
chkStuintToStint(chk, 128, 512)
chkStuintToStint(chk, handleErr, 128, 64)
chkStuintToStint(chk, handleErr, 128, 128)
chkStuintToStint(chk, handleErr, 128, 256)
chkStuintToStint(chk, handleErr, 128, 512)
chkStuintToStint(chk, 256, 64)
chkStuintToStint(chk, 256, 128)
chkStuintToStint(chk, 256, 256)
chkStuintToStint(chk, 256, 512)
chkStuintToStint(chk, handleErr, 256, 64)
chkStuintToStint(chk, handleErr, 256, 128)
chkStuintToStint(chk, handleErr, 256, 256)
chkStuintToStint(chk, handleErr, 256, 512)
chkStuintToStint(chk, 512, 64)
chkStuintToStint(chk, 512, 128)
chkStuintToStint(chk, 512, 256)
chkStuintToStint(chk, 512, 512)
chkStuintToStint(chk, handleErr, 512, 64)
chkStuintToStint(chk, handleErr, 512, 128)
chkStuintToStint(chk, handleErr, 512, 256)
chkStuintToStint(chk, handleErr, 512, 512)
static:
testConversion(ctCheck, ctTest)
testConversion(ctCheck, ctTest, ctExpect)
proc main() =
# Nim GC protests we are using too much global variables
# so put it in a proc
suite "Testing conversion between big integers":
testConversion(check, test)
testConversion(check, test, expect)
main()