add signed int modular arithmetic and tests
This commit is contained in:
parent
867739d2ca
commit
e99bc7ff89
|
@ -7,8 +7,8 @@
|
||||||
#
|
#
|
||||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
import stint/[io, uintops, intops, literals_stint, modular_arithmetic]
|
import stint/[io, uintops, intops, literals_stint, modular_arithmetic, int_modarith]
|
||||||
export io, uintops, intops, literals_stint, modular_arithmetic
|
export io, uintops, intops, literals_stint, modular_arithmetic, int_modarith
|
||||||
|
|
||||||
type
|
type
|
||||||
Int128* = StInt[128]
|
Int128* = StInt[128]
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
# Stint
|
||||||
|
# Copyright 2018-Present Status Research & Development GmbH
|
||||||
|
# Licensed under either of
|
||||||
|
#
|
||||||
|
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
|
||||||
|
#
|
||||||
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
|
import
|
||||||
|
./intops,
|
||||||
|
./modular_arithmetic,
|
||||||
|
private/datatypes
|
||||||
|
|
||||||
|
{.push raises: [], noinit, gcsafe.}
|
||||||
|
|
||||||
|
func addmod*(a, b, m: StInt): StInt =
|
||||||
|
## Modular addition
|
||||||
|
let mt = m.abs
|
||||||
|
if mt.isOne:
|
||||||
|
result.setZero
|
||||||
|
return
|
||||||
|
|
||||||
|
if a.isNegative and b.isNegative:
|
||||||
|
result.impl = addmod(a.neg.impl, b.neg.impl, mt.impl)
|
||||||
|
result.negate
|
||||||
|
elif a.isPositive and b.isPositive:
|
||||||
|
result.impl = addmod(a.impl, b.impl, mt.impl)
|
||||||
|
else:
|
||||||
|
result = a + b
|
||||||
|
result = result mod mt
|
||||||
|
|
||||||
|
func submod*(a, b, m: StInt): StInt =
|
||||||
|
## Modular substraction
|
||||||
|
let mt = m.abs
|
||||||
|
if mt.isOne:
|
||||||
|
result.setZero
|
||||||
|
return
|
||||||
|
|
||||||
|
if a.isNegative and b.isPositive:
|
||||||
|
result.impl = addmod(a.neg.impl, b.impl, mt.impl)
|
||||||
|
result.negate
|
||||||
|
elif a.isPositive and b.isNegative:
|
||||||
|
result.impl = addmod(a.impl, b.neg.impl, mt.impl)
|
||||||
|
else:
|
||||||
|
result = a - b
|
||||||
|
result = result mod mt
|
||||||
|
|
||||||
|
func mulmod*(a, b, m: StInt): StInt =
|
||||||
|
## Modular multiplication
|
||||||
|
|
||||||
|
let mAbs = m.abs
|
||||||
|
if (a.isNegative and b.isPositive) or
|
||||||
|
(a.isPositive and b.isNegative):
|
||||||
|
let xAbs = a.abs
|
||||||
|
let yAbs = b.abs
|
||||||
|
result.impl = mulmod(xAbs.impl, yAbs.impl, mAbs.impl)
|
||||||
|
result.negate
|
||||||
|
else:
|
||||||
|
var xAbs = a
|
||||||
|
var yAbs = b
|
||||||
|
if a.isNegative:
|
||||||
|
xAbs.negate
|
||||||
|
yAbs.negate
|
||||||
|
result.impl = mulmod(xAbs.impl, yAbs.impl, mAbs.impl)
|
||||||
|
|
||||||
|
func powmod*(base, exp, m: StInt): StInt {.raises: [ValueError].} =
|
||||||
|
## Modular exponentiation
|
||||||
|
|
||||||
|
if exp.isNegative:
|
||||||
|
raise newException(ValueError, "exponent must not be negative")
|
||||||
|
|
||||||
|
var
|
||||||
|
bv = base
|
||||||
|
switchSign = false
|
||||||
|
mAbs = m.abs
|
||||||
|
|
||||||
|
if base.isNegative:
|
||||||
|
bv.negate
|
||||||
|
switchSign = exp.isOdd
|
||||||
|
|
||||||
|
result.impl = powmod(bv.impl, exp.impl, mAbs.impl)
|
||||||
|
if switchSign:
|
||||||
|
result.negate
|
||||||
|
|
||||||
|
{.pop.}
|
|
@ -113,6 +113,9 @@ func low*[bits](_: typedesc[StInt[bits]]): StInt[bits] =
|
||||||
func isZero*(a: StInt): bool =
|
func isZero*(a: StInt): bool =
|
||||||
a.impl.isZero
|
a.impl.isZero
|
||||||
|
|
||||||
|
func isOne*(a: StInt): bool =
|
||||||
|
a.impl.isOne
|
||||||
|
|
||||||
func `==`*(a, b: StInt): bool =
|
func `==`*(a, b: StInt): bool =
|
||||||
## Signed int `equal` comparison
|
## Signed int `equal` comparison
|
||||||
a.impl == b.impl
|
a.impl == b.impl
|
||||||
|
|
|
@ -9,6 +9,8 @@
|
||||||
|
|
||||||
import ./uintops, private/datatypes
|
import ./uintops, private/datatypes
|
||||||
|
|
||||||
|
{.push raises: [], gcsafe.}
|
||||||
|
|
||||||
func addmod_internal(a, b, m: StUint): StUint {.inline.}=
|
func addmod_internal(a, b, m: StUint): StUint {.inline.}=
|
||||||
## Modular addition
|
## Modular addition
|
||||||
## ⚠⚠ Assume a < m and b < m
|
## ⚠⚠ Assume a < m and b < m
|
||||||
|
@ -117,3 +119,5 @@ func powmod*(a, b, m: StUint): StUint =
|
||||||
else: a mod m
|
else: a mod m
|
||||||
|
|
||||||
result = powmod_internal(a_m, b, m)
|
result = powmod_internal(a_m, b, m)
|
||||||
|
|
||||||
|
{.pop.}
|
||||||
|
|
|
@ -64,6 +64,14 @@ func isZero*(a: StUint): bool =
|
||||||
return false
|
return false
|
||||||
return true
|
return true
|
||||||
|
|
||||||
|
func isOne*(a: StUint): bool =
|
||||||
|
if a.limbs[0] != 1:
|
||||||
|
return false
|
||||||
|
for i in 1 ..< a.limbs.len:
|
||||||
|
if a.limbs[i] != 0:
|
||||||
|
return false
|
||||||
|
return true
|
||||||
|
|
||||||
func `==`*(a, b: StUint): bool {.inline.} =
|
func `==`*(a, b: StUint): bool {.inline.} =
|
||||||
## Unsigned `equal` comparison
|
## Unsigned `equal` comparison
|
||||||
for i in 0 ..< a.limbs.len:
|
for i in 0 ..< a.limbs.len:
|
||||||
|
|
|
@ -264,6 +264,13 @@ template testComparison(chk, tst: untyped) =
|
||||||
chkIsOdd(chk, "FFFFFFFFFFFFFFF", 128)
|
chkIsOdd(chk, "FFFFFFFFFFFFFFF", 128)
|
||||||
chkIsOdd(chk, "FFFFFFFFFFFFFFFFFF", 256)
|
chkIsOdd(chk, "FFFFFFFFFFFFFFFFFF", 256)
|
||||||
|
|
||||||
|
tst "isOne":
|
||||||
|
let x = 1.i128
|
||||||
|
chk x.isOne
|
||||||
|
|
||||||
|
let y = 1.i256
|
||||||
|
chk y.isOne
|
||||||
|
|
||||||
static:
|
static:
|
||||||
testComparison(ctCheck, ctTest)
|
testComparison(ctCheck, ctTest)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Stint
|
||||||
|
# Copyright 2018-Present Status Research & Development GmbH
|
||||||
|
# Licensed under either of
|
||||||
|
#
|
||||||
|
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
|
||||||
|
#
|
||||||
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
|
import ../stint, unittest, math, test_helpers
|
||||||
|
|
||||||
|
template chkAddMod(chk: untyped, a, b, m, c: string, bits: int) =
|
||||||
|
chk addmod(fromHex(StInt[bits], a), fromHex(StInt[bits], b), fromHex(StInt[bits], m)) == fromHex(StInt[bits], c)
|
||||||
|
|
||||||
|
template chkMulMod(chk: untyped, a, b, m, c: string, bits: int) =
|
||||||
|
chk mulmod(fromHex(StInt[bits], a), fromHex(StInt[bits], b), fromHex(StInt[bits], m)) == fromHex(StInt[bits], c)
|
||||||
|
|
||||||
|
template chkPowMod(chk: untyped, a, b, m, c: string, bits: int) =
|
||||||
|
chk powmod(fromHex(StInt[bits], a), fromHex(StInt[bits], b), fromHex(StInt[bits], m)) == fromHex(StInt[bits], c)
|
||||||
|
|
||||||
|
template testModArith(chk, tst: untyped) =
|
||||||
|
tst "addmod":
|
||||||
|
chkAddMod(chk, "F", "F", "7", "2", 128)
|
||||||
|
chkAddMod(chk, "AAAA", "AA", "F", "0", 128)
|
||||||
|
chkAddMod(chk, "BBBB", "AAAA", "9", "3", 128)
|
||||||
|
chkAddMod(chk, "BBBBBBBB", "AAAAAAAA", "9", "6", 128)
|
||||||
|
chkAddMod(chk, "BBBBBBBBBBBBBBBB", "AAAAAAAAAAAAAAAA", "9", "3", 128)
|
||||||
|
chk addmod(-5.i128, -5.i128, 3.i128) == -1.i128
|
||||||
|
chk addmod(5.i128, -9.i128, 3.i128) == -1.i128
|
||||||
|
chk addmod(-5.i128, 9.i128, 3.i128) == 1.i128
|
||||||
|
|
||||||
|
tst "submod":
|
||||||
|
chk submod(10.i128, 5.i128, 3.i128) == 2.i128
|
||||||
|
chk submod(-6.i128, -5.i128, 3.i128) == -1.i128
|
||||||
|
chk submod(5.i128, -9.i128, 3.i128) == 2.i128
|
||||||
|
chk submod(-5.i128, 9.i128, 3.i128) == -2.i128
|
||||||
|
|
||||||
|
tst "mulmod":
|
||||||
|
chk mulmod(10.i128, 5.i128, 3.i128) == 2.i128
|
||||||
|
chk mulmod(-7.i128, -5.i128, 3.i128) == 2.i128
|
||||||
|
chk mulmod(6.i128, -9.i128, 4.i128) == -2.i128
|
||||||
|
chk mulmod(-5.i128, 7.i128, 3.i128) == -2.i128
|
||||||
|
|
||||||
|
tst "powmod":
|
||||||
|
chk powmod(10.i128, 5.i128, 3.i128) == 1.i128
|
||||||
|
chk powmod(-7.i128, 4.i128, 3.i128) == 1.i128
|
||||||
|
chk powmod(-7.i128, 3.i128, 3.i128) == -1.i128
|
||||||
|
chk powmod(5.i128, 9.i128, 3.i128) == 2.i128
|
||||||
|
chk powmod(-5.i128, 9.i128, 3.i128) == -2.i128
|
||||||
|
|
||||||
|
static:
|
||||||
|
testModArith(ctCheck, ctTest)
|
||||||
|
|
||||||
|
suite "Wider unsigned Modular arithmetic coverage":
|
||||||
|
testModArith(check, test)
|
|
@ -943,7 +943,7 @@ proc main() =
|
||||||
test "Parsing an unexpected 0x prefix for a decimal string is a CatchableError and not a defect":
|
test "Parsing an unexpected 0x prefix for a decimal string is a CatchableError and not a defect":
|
||||||
let s = "0x123456"
|
let s = "0x123456"
|
||||||
|
|
||||||
expect(OverflowDefect):
|
expect(AssertionDefect):
|
||||||
discard parse(s, StUint[256], 10)
|
discard parse(s, StUint[256], 10)
|
||||||
|
|
||||||
suite "Testing conversion functions: Hex, Bytes, Endianness using secp256k1 curve":
|
suite "Testing conversion functions: Hex, Bytes, Endianness using secp256k1 curve":
|
||||||
|
|
Loading…
Reference in New Issue