diff --git a/stint.nim b/stint.nim index 72d21e2..a3ba9ae 100644 --- a/stint.nim +++ b/stint.nim @@ -7,8 +7,8 @@ # # at your option. This file may not be copied, modified, or distributed except according to those terms. -import stint/[io, uintops, intops, literals_stint, modular_arithmetic] -export io, uintops, intops, literals_stint, modular_arithmetic +import stint/[io, uintops, intops, literals_stint, modular_arithmetic, int_modarith] +export io, uintops, intops, literals_stint, modular_arithmetic, int_modarith type Int128* = StInt[128] diff --git a/stint/int_modarith.nim b/stint/int_modarith.nim new file mode 100644 index 0000000..def67ac --- /dev/null +++ b/stint/int_modarith.nim @@ -0,0 +1,86 @@ +# Stint +# Copyright 2018-Present Status Research & Development GmbH +# Licensed under either of +# +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +# +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + ./intops, + ./modular_arithmetic, + private/datatypes + +{.push raises: [], noinit, gcsafe.} + +func addmod*(a, b, m: StInt): StInt = + ## Modular addition + let mt = m.abs + if mt.isOne: + result.setZero + return + + if a.isNegative and b.isNegative: + result.impl = addmod(a.neg.impl, b.neg.impl, mt.impl) + result.negate + elif a.isPositive and b.isPositive: + result.impl = addmod(a.impl, b.impl, mt.impl) + else: + result = a + b + result = result mod mt + +func submod*(a, b, m: StInt): StInt = + ## Modular substraction + let mt = m.abs + if mt.isOne: + result.setZero + return + + if a.isNegative and b.isPositive: + result.impl = addmod(a.neg.impl, b.impl, mt.impl) + result.negate + elif a.isPositive and b.isNegative: + result.impl = addmod(a.impl, b.neg.impl, mt.impl) + else: + result = a - b + result = result mod mt + +func mulmod*(a, b, m: StInt): StInt = + ## Modular multiplication + + let mAbs = m.abs + if (a.isNegative and b.isPositive) or + (a.isPositive and b.isNegative): + let xAbs = a.abs + let yAbs = b.abs + result.impl = mulmod(xAbs.impl, yAbs.impl, mAbs.impl) + result.negate + else: + var xAbs = a + var yAbs = b + if a.isNegative: + xAbs.negate + yAbs.negate + result.impl = mulmod(xAbs.impl, yAbs.impl, mAbs.impl) + +func powmod*(base, exp, m: StInt): StInt {.raises: [ValueError].} = + ## Modular exponentiation + + if exp.isNegative: + raise newException(ValueError, "exponent must not be negative") + + var + bv = base + switchSign = false + mAbs = m.abs + + if base.isNegative: + bv.negate + switchSign = exp.isOdd + + result.impl = powmod(bv.impl, exp.impl, mAbs.impl) + if switchSign: + result.negate + +{.pop.} diff --git a/stint/intops.nim b/stint/intops.nim index 0160cd0..7949ffc 100644 --- a/stint/intops.nim +++ b/stint/intops.nim @@ -113,6 +113,9 @@ func low*[bits](_: typedesc[StInt[bits]]): StInt[bits] = func isZero*(a: StInt): bool = a.impl.isZero +func isOne*(a: StInt): bool = + a.impl.isOne + func `==`*(a, b: StInt): bool = ## Signed int `equal` comparison a.impl == b.impl diff --git a/stint/modular_arithmetic.nim b/stint/modular_arithmetic.nim index 8018366..744b529 100644 --- a/stint/modular_arithmetic.nim +++ b/stint/modular_arithmetic.nim @@ -9,6 +9,8 @@ import ./uintops, private/datatypes +{.push raises: [], gcsafe.} + func addmod_internal(a, b, m: StUint): StUint {.inline.}= ## Modular addition ## ⚠⚠ Assume a < m and b < m @@ -117,3 +119,5 @@ func powmod*(a, b, m: StUint): StUint = else: a mod m result = powmod_internal(a_m, b, m) + +{.pop.} diff --git a/stint/uintops.nim b/stint/uintops.nim index 6d9e676..0c4daf0 100644 --- a/stint/uintops.nim +++ b/stint/uintops.nim @@ -64,6 +64,14 @@ func isZero*(a: StUint): bool = return false return true +func isOne*(a: StUint): bool = + if a.limbs[0] != 1: + return false + for i in 1 ..< a.limbs.len: + if a.limbs[i] != 0: + return false + return true + func `==`*(a, b: StUint): bool {.inline.} = ## Unsigned `equal` comparison for i in 0 ..< a.limbs.len: diff --git a/tests/test_int_comparison.nim b/tests/test_int_comparison.nim index 0f7b291..83d4559 100644 --- a/tests/test_int_comparison.nim +++ b/tests/test_int_comparison.nim @@ -264,6 +264,13 @@ template testComparison(chk, tst: untyped) = chkIsOdd(chk, "FFFFFFFFFFFFFFF", 128) chkIsOdd(chk, "FFFFFFFFFFFFFFFFFF", 256) + tst "isOne": + let x = 1.i128 + chk x.isOne + + let y = 1.i256 + chk y.isOne + static: testComparison(ctCheck, ctTest) diff --git a/tests/test_int_modular_arithmetic.nim b/tests/test_int_modular_arithmetic.nim new file mode 100644 index 0000000..626bc2d --- /dev/null +++ b/tests/test_int_modular_arithmetic.nim @@ -0,0 +1,55 @@ +# Stint +# Copyright 2018-Present Status Research & Development GmbH +# Licensed under either of +# +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +# +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import ../stint, unittest, math, test_helpers + +template chkAddMod(chk: untyped, a, b, m, c: string, bits: int) = + chk addmod(fromHex(StInt[bits], a), fromHex(StInt[bits], b), fromHex(StInt[bits], m)) == fromHex(StInt[bits], c) + +template chkMulMod(chk: untyped, a, b, m, c: string, bits: int) = + chk mulmod(fromHex(StInt[bits], a), fromHex(StInt[bits], b), fromHex(StInt[bits], m)) == fromHex(StInt[bits], c) + +template chkPowMod(chk: untyped, a, b, m, c: string, bits: int) = + chk powmod(fromHex(StInt[bits], a), fromHex(StInt[bits], b), fromHex(StInt[bits], m)) == fromHex(StInt[bits], c) + +template testModArith(chk, tst: untyped) = + tst "addmod": + chkAddMod(chk, "F", "F", "7", "2", 128) + chkAddMod(chk, "AAAA", "AA", "F", "0", 128) + chkAddMod(chk, "BBBB", "AAAA", "9", "3", 128) + chkAddMod(chk, "BBBBBBBB", "AAAAAAAA", "9", "6", 128) + chkAddMod(chk, "BBBBBBBBBBBBBBBB", "AAAAAAAAAAAAAAAA", "9", "3", 128) + chk addmod(-5.i128, -5.i128, 3.i128) == -1.i128 + chk addmod(5.i128, -9.i128, 3.i128) == -1.i128 + chk addmod(-5.i128, 9.i128, 3.i128) == 1.i128 + + tst "submod": + chk submod(10.i128, 5.i128, 3.i128) == 2.i128 + chk submod(-6.i128, -5.i128, 3.i128) == -1.i128 + chk submod(5.i128, -9.i128, 3.i128) == 2.i128 + chk submod(-5.i128, 9.i128, 3.i128) == -2.i128 + + tst "mulmod": + chk mulmod(10.i128, 5.i128, 3.i128) == 2.i128 + chk mulmod(-7.i128, -5.i128, 3.i128) == 2.i128 + chk mulmod(6.i128, -9.i128, 4.i128) == -2.i128 + chk mulmod(-5.i128, 7.i128, 3.i128) == -2.i128 + + tst "powmod": + chk powmod(10.i128, 5.i128, 3.i128) == 1.i128 + chk powmod(-7.i128, 4.i128, 3.i128) == 1.i128 + chk powmod(-7.i128, 3.i128, 3.i128) == -1.i128 + chk powmod(5.i128, 9.i128, 3.i128) == 2.i128 + chk powmod(-5.i128, 9.i128, 3.i128) == -2.i128 + +static: + testModArith(ctCheck, ctTest) + +suite "Wider unsigned Modular arithmetic coverage": + testModArith(check, test) diff --git a/tests/test_io.nim b/tests/test_io.nim index 8d9f879..8299bc3 100644 --- a/tests/test_io.nim +++ b/tests/test_io.nim @@ -943,7 +943,7 @@ proc main() = test "Parsing an unexpected 0x prefix for a decimal string is a CatchableError and not a defect": let s = "0x123456" - expect(OverflowDefect): + expect(AssertionDefect): discard parse(s, StUint[256], 10) suite "Testing conversion functions: Hex, Bytes, Endianness using secp256k1 curve":