implement arithmetic right shift (#76)

* implement arithmetic right shift

* workaround Nim VM 'cast' limitation

* fix high(stint) bug

* fix compile time bit shift bug

* add test for compile time shift and high(stint)

* add tests against ttmath
This commit is contained in:
andri lim 2019-05-11 20:44:41 +07:00 committed by Mamy Ratsimbazafy
parent 9c51f9e7d5
commit ccf87daac1
7 changed files with 308 additions and 107 deletions

View File

@ -78,10 +78,12 @@ make_unary(`not`, Stint)
make_binary(`or`, Stint)
make_binary(`and`, Stint)
make_binary(`xor`, Stint)
# proc `shr`*(x: Stint, y: SomeInteger): Stint {.inline, noSideEffect.} =
# result.data = x.data shr y
# proc `shl`*(x: Stint, y: SomeInteger): Stint {.inline, noSideEffect.} =
# result.data = x.data shl y
func `shr`*(x: Stint, y: SomeInteger): Stint {.inline.} =
result.data = x.data shr y
func `shl`*(x: Stint, y: SomeInteger): Stint {.inline.} =
result.data = x.data shl y
func ashr*(x: Stint, y: SomeInteger): Stint {.inline.} =
result.data = ashr(x.data, y)
import ./private/int_highlow

View File

@ -24,3 +24,74 @@ func `and`*(x, y: IntImpl): IntImpl {.inline.}=
func `xor`*(x, y: IntImpl): IntImpl {.inline.}=
## `Bitwise xor` of numbers x and y
applyHiLo(x, y, `xor`)
func `shr`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}
# Forward declaration
func convertImpl[T: SomeInteger](x: SomeInteger): T {.compileTime.} =
cast[T](x)
func convertImpl[T: IntImpl|UintImpl](x: IntImpl|UintImpl): T {.compileTime.} =
result.hi = convertImpl[type(result.hi)](x.hi)
result.lo = x.lo
template convert[T](x: UintImpl|IntImpl|SomeInteger): T =
when nimvm:
# this is a workaround Nim VM inability to cast
# something non integer
convertImpl[T](x)
else:
cast[T](x)
func `shl`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}=
## Compute the `shift left` operation of x and y
# Note: inlining this poses codegen/aliasing issue when doing `x = x shl 1`
# TODO: would it be better to reimplement this with words iteration?
const halfSize: type(y) = bitsof(x) div 2
type HiType = type(result.hi)
if y == 0:
return x
elif y == halfSize:
result.hi = convert[HiType](x.lo)
elif y < halfSize:
result.hi = (x.hi shl y) or convert[HiType](x.lo shr (halfSize - y))
result.lo = x.lo shl y
else:
result.hi = convert[HiType](x.lo shl (y - halfSize))
func `shr`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}=
## Compute the `shift right` operation of x and y
## Similar to C standard, result is undefined if y is bigger
## than the number of bits in x.
const halfSize: type(y) = bitsof(x) div 2
type LoType = type(result.lo)
if y == 0:
return x
elif y == halfSize:
result.lo = convert[LoType](x.hi)
elif y < halfSize:
result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y))
result.hi = x.hi shr y
else:
result.lo = convert[LoType](x.hi shr (y - halfSize))
func ashr*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}=
## Compute the `arithmetic shift right` operation of x and y
## Similar to C standard, result is undefined if y is bigger
## than the number of bits in x.
const halfSize: type(y) = bitsof(x) div 2
type LoType = type(result.lo)
if y == 0:
return x
elif y == halfSize:
result.lo = convert[LoType](x.hi)
result.hi = ashr(x.hi, halfSize-1)
elif y < halfSize:
result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y))
result.hi = ashr(x.hi, y)
else:
result.lo = convert[LoType](ashr(x.hi, (y - halfSize)))
result.hi = ashr(x.hi, halfSize-1)

View File

@ -7,11 +7,11 @@
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ./datatypes, ./int_bitwise_ops, ./initialization, ./uint_highlow
import ./datatypes, ./int_bitwise_ops, ./initialization, ./uint_highlow, typetraits
# XXX There's some Araq reason why this isn't part of the std lib..
func high(_: typedesc[SomeUnsignedInt]): SomeUnsignedInt =
not SomeUnsignedInt(0'u8)
func high(T: typedesc[SomeUnsignedInt]): T =
not T(0)
func high*[T, T2](_: typedesc[IntImpl[T, T2]]): IntImpl[T, T2] {.inline.}=
# The highest signed int has representation

View File

@ -20,6 +20,7 @@ import test_int_endianness,
test_int_comparison,
test_int_addsub,
test_int_muldiv,
test_int_boundchecks
test_int_boundchecks,
test_int_bitwise
import test_io

View File

@ -25,16 +25,16 @@ suite "Property-based testing (testing with random inputs) of Uint256":
else:
echo "(StUint[64] = uint64)"
let hi = 1'u shl (sizeof(uint)*7)
let hi = 1'u shl (sizeof(uint64)*7)
quicktest "`or`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`or`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -51,14 +51,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`and`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`and`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -75,14 +75,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`xor`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`xor`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -100,10 +100,10 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
# Not defined for ttmath
# quicktest "`not`", itercount do(x0: uint(min=0, max=hi),
# x1: uint(min=0, max=hi),
# x2: uint(min=0, max=hi),
# x3: uint(min=0, max=hi):
# quicktest "`not`", itercount do(x0: uint64(min=0, max=hi),
# x1: uint64(min=0, max=hi),
# x2: uint64(min=0, max=hi),
# x3: uint64(min=0, max=hi):
# let
# x = [x0, x1, x2, x3]
@ -118,14 +118,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
# check(cast[array[4, uint64]](ttm_z) == cast[array[4, uint64]](mp_z))
quicktest "`<`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`<`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -143,14 +143,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check(ttm_z == mp_z)
quicktest "`<=`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`<=`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -167,14 +167,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check(ttm_z == mp_z)
quicktest "`+`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`+`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -191,14 +191,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`-`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`-`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -215,14 +215,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`*`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`*`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -239,10 +239,10 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`shl`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
quicktest "`shl`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y: int(min = 0, max=(255))):
let
@ -257,10 +257,10 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`shr`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
quicktest "`shr`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y: int(min = 0, max=(255))):
let
@ -275,14 +275,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`mod`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`mod`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -299,14 +299,14 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "`div`", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
y0: uint(min=0, max=hi),
y1: uint(min=0, max=hi),
y2: uint(min=0, max=hi),
y3: uint(min=0, max=hi)):
quicktest "`div`", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y0: uint64(min=0, max=hi),
y1: uint64(min=0, max=hi),
y2: uint64(min=0, max=hi),
y3: uint64(min=0, max=hi)):
let
x = [x0, x1, x2, x3]
@ -323,10 +323,10 @@ suite "Property-based testing (testing with random inputs) of Uint256":
check ttm_z.asSt == mp_z
quicktest "pow", itercount do(x0: uint(min=0, max=hi),
x1: uint(min=0, max=hi),
x2: uint(min=0, max=hi),
x3: uint(min=0, max=hi),
quicktest "pow", itercount do(x0: uint64(min=0, max=hi),
x1: uint64(min=0, max=hi),
x2: uint64(min=0, max=hi),
x3: uint64(min=0, max=hi),
y : int(min=0, max=high(int))):
let

127
tests/test_int_bitwise.nim Normal file
View File

@ -0,0 +1,127 @@
# Stint
# Copyright 2018 Status Research & Development GmbH
# Licensed under either of
#
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
#
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import ../stint, unittest
when defined(cpp):
import quicktest, ttmath_compat
func high(T: typedesc[SomeUnsignedInt]): T =
not T(0)
suite "Testing signed int bitwise operations":
const
hi = high(int64)
lo = low(int64)
itercount = 1000
test "Shift Left":
var y = 1.u256
for i in 1..255:
let x = 1.i256 shl i
y = y shl 1
check cast[stint.Uint256](x) == y
test "Shift Right":
const leftMost = 1.i256 shl 255
var y = 1.u256 shl 255
for i in 1..255:
let x = leftMost shr i
y = y shr 1
check cast[stint.Uint256](x) == y
test "ashr on positive int":
const leftMost = 1.i256 shl 254
var y = 1.u256 shl 254
for i in 1..255:
let x = ashr(leftMost, i)
y = y shr 1
check x == cast[stint.Int256](y)
test "ashr on negative int":
const
leftMostU = 1.u256 shl 255
leftMostI = 1.i256 shl 255
var y = leftMostU
for i in 1..255:
let x = ashr(leftMostI, i)
y = (y shr 1) or leftMostU
check x == cast[stint.Int256](y)
test "Compile time shift":
const
# set all bits
x = high(stint.Int256) or (1.i256 shl 255)
y = not 0.i256
check x == y
const
a = (high(stint.Int256) shl 10) shr 10
b = (high(stint.Uint256) shl 10) shr 10
c = ashr(high(stint.Int256) shl 10, 10)
check a == cast[stint.Int256](b)
check c != cast[stint.Int256](b)
check c != a
when defined(cpp):
quicktest "signed int `shl` vs ttmath", itercount do(x0: int64(min=lo, max=hi),
x1: int64(min=0, max=hi),
x2: int64(min=0, max=hi),
x3: int64(min=0, max=hi),
y: int(min=0, max=(255))):
let
x = [x0, x1, x2, x3]
ttm_x = x.asTT
mp_x = cast[stint.Int256](x)
let
ttm_z = ttm_x shl y
mp_z = mp_x shl y
check ttm_z.asSt == mp_z
quicktest "signed int `shr` vs ttmath", itercount do(x0: int64(min=lo, max=hi),
x1: int64(min=0, max=hi),
x2: int64(min=0, max=hi),
x3: int64(min=0, max=hi),
y: int(min=0, max=(255))):
let
x = [cast[uint64](x0), cast[uint64](x1), cast[uint64](x2), cast[uint64](x3)]
ttm_x = x.asTT
mp_x = cast[stint.Int256](x)
let
ttm_z = ttm_x shr y.uint
mp_z = mp_x shr y
check cast[stint.Int256](ttm_z.asSt) == mp_z
quicktest "arithmetic shift right vs ttmath", itercount do(x0: int64(min=lo, max=hi),
x1: int64(min=0, max=hi),
x2: int64(min=0, max=hi),
x3: int64(min=0, max=hi),
y: int(min=0, max=(255))):
let
x = [x0, x1, x2, x3]
ttm_x = x.asTT
mp_x = cast[stint.Int256](x)
let
ttm_z = ttm_x shr y # C/CPP usually implement `shr` as `ashr` a.k.a. `sar`
mp_z = ashr(mp_x, y)
check ttm_z.asSt == mp_z

View File

@ -9,11 +9,11 @@ template asSt*(val: Int): auto =
type TargetType = StInt[val.NumBits]
cast[ptr TargetType](unsafeAddr val)[]
template asTT*[N: static[int]](arr: array[N, uint]): auto =
template asTT*[N: static[int]](arr: array[N, uint64]): auto =
type TargetType = UInt[N * 64]
cast[ptr TargetType](unsafeAddr arr[0])[]
template asTT*[N: static[int]](arr: array[N, int]): auto =
template asTT*[N: static[int]](arr: array[N, int64]): auto =
type TargetType = Int[N * 64]
cast[ptr TargetType](unsafeAddr arr[0])[]