diff --git a/stint/io.nim b/stint/io.nim index a0ea092..884865b 100644 --- a/stint/io.nim +++ b/stint/io.nim @@ -213,14 +213,14 @@ func skipPrefixes(current_idx: var int, str: string, radix: range[2..16]) {.inli elif str[1] in {'o', 'O'}: doAssert radix == 8, "Parsing mismatch, 0o prefix is only valid for an octal number (base 8)" current_idx = 2 - elif str[1] in {'b', 'B'}: + elif str[1] in {'b', 'B'}: if radix == 2: current_idx = 2 elif radix == 16: # allow something like "0bcdef12345" which is a valid hex current_idx = 0 else: - doAssert false, "Parsing mismatch, 0b prefix is only valid for a binary number (base 2), or hex number" + doAssert false, "Parsing mismatch, 0b prefix is only valid for a binary number (base 2), or hex number" func nextNonBlank(current_idx: var int, s: string) {.inline.} = ## Move the current index, skipping white spaces and "_" characters. @@ -333,7 +333,7 @@ func toString*[bits: static[int]](num: StUint[bits], radix: static[uint8] = 10): reverse(result) -func toString*[bits: static[int]](num: StInt[bits], radix: static[int8] = 10): string = +func toString*[bits: static[int]](num: StInt[bits], radix: static[uint8] = 10): string = ## Convert a Stint or StUint to string. ## In case of negative numbers: ## - they are prefixed with "-" for base 10. @@ -466,26 +466,16 @@ template toBytesBE*[bits: static[int]](n: StInt[bits]): array[bits div 8, byte] {.pop.} -func getRadix(s: static string): uint8 {.compileTime.} = - if s.len <= 2: - return 10 - - # maybe have prefix have prefix - if s[0] != '0': - return 10 +include + ./private/custom_literal - if s[1] == 'b': - return 2 - - if s[1] == 'o': - return 8 - - if s[1] == 'x': - return 16 - func customLiteral*(T: type SomeBigInteger, s: static string): T = when s.len == 0: doAssert(false, "customLiteral cannot accept param with zero length") - + const radix = getRadix(s) + type TT = T + when isOverflow(TT, s, radix): + {.error: "Stint custom literal overlow detected" .} + parse(s, T, radix) diff --git a/stint/private/custom_literal.nim b/stint/private/custom_literal.nim new file mode 100644 index 0000000..cbba852 --- /dev/null +++ b/stint/private/custom_literal.nim @@ -0,0 +1,47 @@ +# Stint +# Copyright 2018-2023 Status Research & Development GmbH +# Licensed under either of +# +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +# +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +func getRadix(s: static string): uint8 {.compileTime.} = + if s.len <= 2: + return 10 + + # maybe have prefix have prefix + if s[0] != '0': + return 10 + + if s[1] == 'b': + return 2 + + if s[1] == 'o': + return 8 + + if s[1] == 'x': + return 16 + +func stripPrefix(s: string): string {.compileTime.} = + if s[0] != '0': + return s + if s[1] in {'b', 'o', 'x'}: + return s[2 .. ^1] + s + +func stripLeadingZeros(value: string): string {.compileTime.} = + var cidx = 0 + # ignore the last character so we retain '0' on zero value + while cidx < value.len - 1 and value[cidx] == '0': + cidx.inc + value[cidx .. ^1] + +func isOverflow(T: type SomeBigInteger, s: static string, radix: static uint8): bool {.compileTime.} = + # a stupid but effective overflow detection + # it's a compiletime check anyway + let tmp = parse(s, T, radix) + let litStr = tmp.toString(radix) + let normalizedSrc = s.stripPrefix.stripLeadingZeros + litStr != normalizedSrc diff --git a/tests/test_features.nim b/tests/test_features.nim index 814e87f..7534905 100644 --- a/tests/test_features.nim +++ b/tests/test_features.nim @@ -6,7 +6,12 @@ # # at your option. This file may not be copied, modified, or distributed except according to those terms. -import ../stint, unittest +import + ../stint, + unittest + +template reject(code: untyped) = + static: assert(not compiles(code)) suite "new features": test "custom literal": @@ -19,6 +24,8 @@ suite "new features": let x = 0b111100011'u128 y = 0o777766666'u256 + z = 0x1122334455667788991011121314151617181920aabbccddeeffb1b2b3b4b500'u256 + w = 340282366920938463463374607431768211455'u128 check: a == 0xabcdef0123456.u128 @@ -27,3 +34,21 @@ suite "new features": d == -50000.i256 x == 0b111100011.u128 y == 0o777766666.u256 + z == UInt256.fromHex("0x1122334455667788991011121314151617181920aabbccddeeffb1b2b3b4b500") + w == UInt128.fromDecimal("340282366920938463463374607431768211455") + + test "custom literal overflow": + reject: + const + z = 0x1122334455667788991011121314151617181920aabbccddeeffb1b2b3b4b5700'u256 + doAssert(false) + + reject: + let + z = 0x1122334455667788991011121314151617181920aabbccddeeffb1b2b3b4b5700'u256 + doAssert(false) + + reject: + const + w = 1122334455667788991011121314151617181920'u128 + doAssert(false)