From ccf87daac1eef15238ff3d6d2edb138e22180d19 Mon Sep 17 00:00:00 2001 From: andri lim Date: Sat, 11 May 2019 20:44:41 +0700 Subject: [PATCH] implement arithmetic right shift (#76) * implement arithmetic right shift * workaround Nim VM 'cast' limitation * fix high(stint) bug * fix compile time bit shift bug * add test for compile time shift and high(stint) * add tests against ttmath --- stint/int_public.nim | 10 +- stint/private/int_bitwise_ops.nim | 71 +++++++++++ stint/private/int_highlow.nim | 6 +- tests/all_tests.nim | 3 +- tests/property_based_uint256.nim | 194 +++++++++++++++--------------- tests/test_int_bitwise.nim | 127 +++++++++++++++++++ tests/ttmath_compat.nim | 4 +- 7 files changed, 308 insertions(+), 107 deletions(-) create mode 100644 tests/test_int_bitwise.nim diff --git a/stint/int_public.nim b/stint/int_public.nim index 45ca208..4dc7bfc 100644 --- a/stint/int_public.nim +++ b/stint/int_public.nim @@ -78,10 +78,12 @@ make_unary(`not`, Stint) make_binary(`or`, Stint) make_binary(`and`, Stint) make_binary(`xor`, Stint) -# proc `shr`*(x: Stint, y: SomeInteger): Stint {.inline, noSideEffect.} = -# result.data = x.data shr y -# proc `shl`*(x: Stint, y: SomeInteger): Stint {.inline, noSideEffect.} = -# result.data = x.data shl y +func `shr`*(x: Stint, y: SomeInteger): Stint {.inline.} = + result.data = x.data shr y +func `shl`*(x: Stint, y: SomeInteger): Stint {.inline.} = + result.data = x.data shl y +func ashr*(x: Stint, y: SomeInteger): Stint {.inline.} = + result.data = ashr(x.data, y) import ./private/int_highlow diff --git a/stint/private/int_bitwise_ops.nim b/stint/private/int_bitwise_ops.nim index 33fe27f..0c0b15d 100644 --- a/stint/private/int_bitwise_ops.nim +++ b/stint/private/int_bitwise_ops.nim @@ -24,3 +24,74 @@ func `and`*(x, y: IntImpl): IntImpl {.inline.}= func `xor`*(x, y: IntImpl): IntImpl {.inline.}= ## `Bitwise xor` of numbers x and y applyHiLo(x, y, `xor`) + +func `shr`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.} + # Forward declaration + +func convertImpl[T: SomeInteger](x: SomeInteger): T {.compileTime.} = + cast[T](x) + +func convertImpl[T: IntImpl|UintImpl](x: IntImpl|UintImpl): T {.compileTime.} = + result.hi = convertImpl[type(result.hi)](x.hi) + result.lo = x.lo + +template convert[T](x: UintImpl|IntImpl|SomeInteger): T = + when nimvm: + # this is a workaround Nim VM inability to cast + # something non integer + convertImpl[T](x) + else: + cast[T](x) + +func `shl`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= + ## Compute the `shift left` operation of x and y + # Note: inlining this poses codegen/aliasing issue when doing `x = x shl 1` + + # TODO: would it be better to reimplement this with words iteration? + const halfSize: type(y) = bitsof(x) div 2 + type HiType = type(result.hi) + + if y == 0: + return x + elif y == halfSize: + result.hi = convert[HiType](x.lo) + elif y < halfSize: + result.hi = (x.hi shl y) or convert[HiType](x.lo shr (halfSize - y)) + result.lo = x.lo shl y + else: + result.hi = convert[HiType](x.lo shl (y - halfSize)) + +func `shr`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= + ## Compute the `shift right` operation of x and y + ## Similar to C standard, result is undefined if y is bigger + ## than the number of bits in x. + const halfSize: type(y) = bitsof(x) div 2 + type LoType = type(result.lo) + + if y == 0: + return x + elif y == halfSize: + result.lo = convert[LoType](x.hi) + elif y < halfSize: + result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y)) + result.hi = x.hi shr y + else: + result.lo = convert[LoType](x.hi shr (y - halfSize)) + +func ashr*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= + ## Compute the `arithmetic shift right` operation of x and y + ## Similar to C standard, result is undefined if y is bigger + ## than the number of bits in x. + const halfSize: type(y) = bitsof(x) div 2 + type LoType = type(result.lo) + if y == 0: + return x + elif y == halfSize: + result.lo = convert[LoType](x.hi) + result.hi = ashr(x.hi, halfSize-1) + elif y < halfSize: + result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y)) + result.hi = ashr(x.hi, y) + else: + result.lo = convert[LoType](ashr(x.hi, (y - halfSize))) + result.hi = ashr(x.hi, halfSize-1) diff --git a/stint/private/int_highlow.nim b/stint/private/int_highlow.nim index dcb7cfd..1337065 100644 --- a/stint/private/int_highlow.nim +++ b/stint/private/int_highlow.nim @@ -7,11 +7,11 @@ # # at your option. This file may not be copied, modified, or distributed except according to those terms. -import ./datatypes, ./int_bitwise_ops, ./initialization, ./uint_highlow +import ./datatypes, ./int_bitwise_ops, ./initialization, ./uint_highlow, typetraits # XXX There's some Araq reason why this isn't part of the std lib.. -func high(_: typedesc[SomeUnsignedInt]): SomeUnsignedInt = - not SomeUnsignedInt(0'u8) +func high(T: typedesc[SomeUnsignedInt]): T = + not T(0) func high*[T, T2](_: typedesc[IntImpl[T, T2]]): IntImpl[T, T2] {.inline.}= # The highest signed int has representation diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 073c4ec..32ebcc7 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -20,6 +20,7 @@ import test_int_endianness, test_int_comparison, test_int_addsub, test_int_muldiv, - test_int_boundchecks + test_int_boundchecks, + test_int_bitwise import test_io diff --git a/tests/property_based_uint256.nim b/tests/property_based_uint256.nim index 274ae82..d722a06 100644 --- a/tests/property_based_uint256.nim +++ b/tests/property_based_uint256.nim @@ -25,16 +25,16 @@ suite "Property-based testing (testing with random inputs) of Uint256": else: echo "(StUint[64] = uint64)" - let hi = 1'u shl (sizeof(uint)*7) + let hi = 1'u shl (sizeof(uint64)*7) - quicktest "`or`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`or`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -51,14 +51,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`and`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`and`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -75,14 +75,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`xor`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`xor`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -100,10 +100,10 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z # Not defined for ttmath - # quicktest "`not`", itercount do(x0: uint(min=0, max=hi), - # x1: uint(min=0, max=hi), - # x2: uint(min=0, max=hi), - # x3: uint(min=0, max=hi): + # quicktest "`not`", itercount do(x0: uint64(min=0, max=hi), + # x1: uint64(min=0, max=hi), + # x2: uint64(min=0, max=hi), + # x3: uint64(min=0, max=hi): # let # x = [x0, x1, x2, x3] @@ -118,14 +118,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": # check(cast[array[4, uint64]](ttm_z) == cast[array[4, uint64]](mp_z)) - quicktest "`<`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`<`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -143,14 +143,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check(ttm_z == mp_z) - quicktest "`<=`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`<=`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -167,14 +167,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check(ttm_z == mp_z) - quicktest "`+`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`+`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -191,14 +191,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`-`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`-`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -215,14 +215,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`*`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`*`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -239,10 +239,10 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`shl`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), + quicktest "`shl`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), y: int(min = 0, max=(255))): let @@ -257,10 +257,10 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`shr`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), + quicktest "`shr`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), y: int(min = 0, max=(255))): let @@ -275,14 +275,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`mod`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`mod`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -299,14 +299,14 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "`div`", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), - y0: uint(min=0, max=hi), - y1: uint(min=0, max=hi), - y2: uint(min=0, max=hi), - y3: uint(min=0, max=hi)): + quicktest "`div`", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), + y0: uint64(min=0, max=hi), + y1: uint64(min=0, max=hi), + y2: uint64(min=0, max=hi), + y3: uint64(min=0, max=hi)): let x = [x0, x1, x2, x3] @@ -323,10 +323,10 @@ suite "Property-based testing (testing with random inputs) of Uint256": check ttm_z.asSt == mp_z - quicktest "pow", itercount do(x0: uint(min=0, max=hi), - x1: uint(min=0, max=hi), - x2: uint(min=0, max=hi), - x3: uint(min=0, max=hi), + quicktest "pow", itercount do(x0: uint64(min=0, max=hi), + x1: uint64(min=0, max=hi), + x2: uint64(min=0, max=hi), + x3: uint64(min=0, max=hi), y : int(min=0, max=high(int))): let diff --git a/tests/test_int_bitwise.nim b/tests/test_int_bitwise.nim new file mode 100644 index 0000000..8660ea0 --- /dev/null +++ b/tests/test_int_bitwise.nim @@ -0,0 +1,127 @@ +# Stint +# Copyright 2018 Status Research & Development GmbH +# Licensed under either of +# +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +# +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import ../stint, unittest + +when defined(cpp): + import quicktest, ttmath_compat + +func high(T: typedesc[SomeUnsignedInt]): T = + not T(0) + +suite "Testing signed int bitwise operations": + const + hi = high(int64) + lo = low(int64) + itercount = 1000 + + test "Shift Left": + var y = 1.u256 + for i in 1..255: + let x = 1.i256 shl i + y = y shl 1 + check cast[stint.Uint256](x) == y + + test "Shift Right": + const leftMost = 1.i256 shl 255 + var y = 1.u256 shl 255 + for i in 1..255: + let x = leftMost shr i + y = y shr 1 + check cast[stint.Uint256](x) == y + + test "ashr on positive int": + const leftMost = 1.i256 shl 254 + var y = 1.u256 shl 254 + for i in 1..255: + let x = ashr(leftMost, i) + y = y shr 1 + check x == cast[stint.Int256](y) + + test "ashr on negative int": + const + leftMostU = 1.u256 shl 255 + leftMostI = 1.i256 shl 255 + var y = leftMostU + for i in 1..255: + let x = ashr(leftMostI, i) + y = (y shr 1) or leftMostU + check x == cast[stint.Int256](y) + + test "Compile time shift": + const + # set all bits + x = high(stint.Int256) or (1.i256 shl 255) + y = not 0.i256 + + check x == y + + const + a = (high(stint.Int256) shl 10) shr 10 + b = (high(stint.Uint256) shl 10) shr 10 + c = ashr(high(stint.Int256) shl 10, 10) + + check a == cast[stint.Int256](b) + check c != cast[stint.Int256](b) + check c != a + + when defined(cpp): + quicktest "signed int `shl` vs ttmath", itercount do(x0: int64(min=lo, max=hi), + x1: int64(min=0, max=hi), + x2: int64(min=0, max=hi), + x3: int64(min=0, max=hi), + y: int(min=0, max=(255))): + + let + x = [x0, x1, x2, x3] + + ttm_x = x.asTT + mp_x = cast[stint.Int256](x) + + let + ttm_z = ttm_x shl y + mp_z = mp_x shl y + + check ttm_z.asSt == mp_z + + quicktest "signed int `shr` vs ttmath", itercount do(x0: int64(min=lo, max=hi), + x1: int64(min=0, max=hi), + x2: int64(min=0, max=hi), + x3: int64(min=0, max=hi), + y: int(min=0, max=(255))): + + let + x = [cast[uint64](x0), cast[uint64](x1), cast[uint64](x2), cast[uint64](x3)] + + ttm_x = x.asTT + mp_x = cast[stint.Int256](x) + + let + ttm_z = ttm_x shr y.uint + mp_z = mp_x shr y + + check cast[stint.Int256](ttm_z.asSt) == mp_z + + quicktest "arithmetic shift right vs ttmath", itercount do(x0: int64(min=lo, max=hi), + x1: int64(min=0, max=hi), + x2: int64(min=0, max=hi), + x3: int64(min=0, max=hi), + y: int(min=0, max=(255))): + + let + x = [x0, x1, x2, x3] + + ttm_x = x.asTT + mp_x = cast[stint.Int256](x) + + let + ttm_z = ttm_x shr y # C/CPP usually implement `shr` as `ashr` a.k.a. `sar` + mp_z = ashr(mp_x, y) + + check ttm_z.asSt == mp_z diff --git a/tests/ttmath_compat.nim b/tests/ttmath_compat.nim index 29101dd..eda9fbc 100644 --- a/tests/ttmath_compat.nim +++ b/tests/ttmath_compat.nim @@ -9,11 +9,11 @@ template asSt*(val: Int): auto = type TargetType = StInt[val.NumBits] cast[ptr TargetType](unsafeAddr val)[] -template asTT*[N: static[int]](arr: array[N, uint]): auto = +template asTT*[N: static[int]](arr: array[N, uint64]): auto = type TargetType = UInt[N * 64] cast[ptr TargetType](unsafeAddr arr[0])[] -template asTT*[N: static[int]](arr: array[N, int]): auto = +template asTT*[N: static[int]](arr: array[N, int64]): auto = type TargetType = Int[N * 64] cast[ptr TargetType](unsafeAddr arr[0])[]