diff --git a/stint/intops.nim b/stint/intops.nim index ba75b61..9751f05 100644 --- a/stint/intops.nim +++ b/stint/intops.nim @@ -130,8 +130,6 @@ func `shr`*(x: SomeBigInteger, y: SomeInteger): SomeBigInteger {.inline.} = result.data = x.data shr y func `shl`*(x: SomeBigInteger, y: SomeInteger): SomeBigInteger {.inline.} = result.data = x.data shl y -func ashr*(x: Stint, y: SomeInteger): Stint {.inline.} = - result.data = ashr(x.data, y) import ./private/[int_highlow, uint_highlow] diff --git a/stint/private/int_bitwise_ops.nim b/stint/private/int_bitwise_ops.nim index fe10870..bf1c5de 100644 --- a/stint/private/int_bitwise_ops.nim +++ b/stint/private/int_bitwise_ops.nim @@ -25,9 +25,6 @@ func `xor`*(x, y: IntImpl): IntImpl {.inline.}= ## `Bitwise xor` of numbers x and y applyHiLo(x, y, `xor`) -func `shr`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.} - # Forward declaration - func convertImpl[T: SomeInteger](x: SomeInteger): T {.compileTime.} = cast[T](x) @@ -56,42 +53,46 @@ func `shl`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= elif y == halfSize: result.hi = convert[HiType](x.lo) elif y < halfSize: + # `shr` in this equation uses uint version result.hi = (x.hi shl y) or convert[HiType](x.lo shr (halfSize - y)) result.lo = x.lo shl y else: result.hi = convert[HiType](x.lo shl (y - halfSize)) -func `shr`*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= - ## Compute the `shift right` operation of x and y - ## Similar to C standard, result is undefined if y is bigger - ## than the number of bits in x. - const halfSize: type(y) = bitsof(x) div 2 - type LoType = type(result.lo) +template createShr(name, operator: untyped) = + template name(x, y: SomeInteger): auto = + operator(x, y) - if y == 0: - return x - elif y == halfSize: - result.lo = convert[LoType](x.hi) - elif y < halfSize: - result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y)) - result.hi = x.hi shr y - else: - result.lo = convert[LoType](x.hi shr (y - halfSize)) + func name*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= + ## Compute the `arithmetic shift right` operation of x and y + ## Similar to C standard, result is undefined if y is bigger + ## than the number of bits in x. + const halfSize: type(y) = bitsof(x) div 2 + type LoType = type(result.lo) + if y == 0: + return x + elif y == halfSize: + result.lo = convert[LoType](x.hi) + result.hi = name(x.hi, halfSize-1) + elif y < halfSize: + result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y)) + result.hi = name(x.hi, y) + else: + result.lo = convert[LoType](name(x.hi, (y - halfSize))) + result.hi = name(x.hi, halfSize-1) -func ashr*(x: IntImpl, y: SomeInteger): IntImpl {.inline.}= - ## Compute the `arithmetic shift right` operation of x and y - ## Similar to C standard, result is undefined if y is bigger - ## than the number of bits in x. - const halfSize: type(y) = bitsof(x) div 2 - type LoType = type(result.lo) - if y == 0: - return x - elif y == halfSize: - result.lo = convert[LoType](x.hi) - result.hi = ashr(x.hi, halfSize-1) - elif y < halfSize: - result.lo = (x.lo shr y) or convert[LoType](x.hi shl (halfSize - y)) - result.hi = ashr(x.hi, y) - else: - result.lo = convert[LoType](ashr(x.hi, (y - halfSize))) - result.hi = ashr(x.hi, halfSize-1) +template nimVersionIs(comparator: untyped, major, minor, patch: int): bool = + comparator(NimMajor * 100 + NimMinor * 10 + NimPatch, major * 100 + minor * 10 + patch) + +when nimVersionIs(`>=`, 0, 20, 0): + createShr(shrOfShr, `shr`) +elif nimVersionIs(`<`, 0, 20, 0) and defined(nimAshr): + createShr(shrOfAshr, ashr) +else: + {.error: "arithmetic right shift is not defined for this Nim version".} + +template `shr`*(a, b: typed): untyped = + when nimVersionIs(`>=`, 0, 20, 0): + shrOfShr(a, b) + elif nimVersionIs(`<`, 0, 20, 0) and defined(nimAshr): + shrOfAShr(a, b) diff --git a/tests/test_int_bitwise.nim b/tests/test_int_bitwise.nim index 8660ea0..2600e4e 100644 --- a/tests/test_int_bitwise.nim +++ b/tests/test_int_bitwise.nim @@ -28,29 +28,21 @@ suite "Testing signed int bitwise operations": y = y shl 1 check cast[stint.Uint256](x) == y - test "Shift Right": - const leftMost = 1.i256 shl 255 - var y = 1.u256 shl 255 - for i in 1..255: - let x = leftMost shr i - y = y shr 1 - check cast[stint.Uint256](x) == y - - test "ashr on positive int": + test "Shift Right on positive int": const leftMost = 1.i256 shl 254 var y = 1.u256 shl 254 for i in 1..255: - let x = ashr(leftMost, i) + let x = leftMost shr i y = y shr 1 check x == cast[stint.Int256](y) - test "ashr on negative int": + test "Shift Right on negative int": const leftMostU = 1.u256 shl 255 leftMostI = 1.i256 shl 255 var y = leftMostU for i in 1..255: - let x = ashr(leftMostI, i) + let x = leftMostI shr i y = (y shr 1) or leftMostU check x == cast[stint.Int256](y) @@ -65,11 +57,11 @@ suite "Testing signed int bitwise operations": const a = (high(stint.Int256) shl 10) shr 10 b = (high(stint.Uint256) shl 10) shr 10 - c = ashr(high(stint.Int256) shl 10, 10) + c = (high(stint.Int256) shl 10) shr 10 - check a == cast[stint.Int256](b) + check a != cast[stint.Int256](b) check c != cast[stint.Int256](b) - check c != a + check c == a when defined(cpp): quicktest "signed int `shl` vs ttmath", itercount do(x0: int64(min=lo, max=hi), @@ -90,24 +82,6 @@ suite "Testing signed int bitwise operations": check ttm_z.asSt == mp_z - quicktest "signed int `shr` vs ttmath", itercount do(x0: int64(min=lo, max=hi), - x1: int64(min=0, max=hi), - x2: int64(min=0, max=hi), - x3: int64(min=0, max=hi), - y: int(min=0, max=(255))): - - let - x = [cast[uint64](x0), cast[uint64](x1), cast[uint64](x2), cast[uint64](x3)] - - ttm_x = x.asTT - mp_x = cast[stint.Int256](x) - - let - ttm_z = ttm_x shr y.uint - mp_z = mp_x shr y - - check cast[stint.Int256](ttm_z.asSt) == mp_z - quicktest "arithmetic shift right vs ttmath", itercount do(x0: int64(min=lo, max=hi), x1: int64(min=0, max=hi), x2: int64(min=0, max=hi), @@ -122,6 +96,6 @@ suite "Testing signed int bitwise operations": let ttm_z = ttm_x shr y # C/CPP usually implement `shr` as `ashr` a.k.a. `sar` - mp_z = ashr(mp_x, y) + mp_z = mp_x shr y check ttm_z.asSt == mp_z