From f2d0eab153971430c3383bc118e9f0b2f43b50a7 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Wed, 16 May 2018 10:41:46 +0200 Subject: [PATCH] Modular arithmetic (#47) * Add isEven and isOdd functions * Add modular add, mul, sub pow fixes #18 --- stint.nim | 4 +- stint/int_public.nim | 15 +++- stint/modular_arithmetic.nim | 119 ++++++++++++++++++++++++++++++ stint/private/int_comparison.nim | 6 ++ stint/private/uint_comparison.nim | 6 ++ stint/private/uint_exp.nim | 4 +- stint/uint_public.nim | 13 +++- tests/test_int_comparison.nim | 9 +++ tests/test_uint_comparison.nim | 9 +++ tests/uint_modular_arithmetic.nim | 46 ++++++++++++ 10 files changed, 224 insertions(+), 7 deletions(-) create mode 100644 stint/modular_arithmetic.nim create mode 100644 tests/uint_modular_arithmetic.nim diff --git a/stint.nim b/stint.nim index 925453b..007e08d 100644 --- a/stint.nim +++ b/stint.nim @@ -7,8 +7,8 @@ # # at your option. This file may not be copied, modified, or distributed except according to those terms. -import stint/[uint_public, int_public, io] -export uint_public, int_public, io +import stint/[uint_public, int_public, io, modular_arithmetic] +export uint_public, int_public, io, modular_arithmetic type Int128* = Stint[128] diff --git a/stint/int_public.nim b/stint/int_public.nim index 6bc0e7f..f058e45 100644 --- a/stint/int_public.nim +++ b/stint/int_public.nim @@ -57,8 +57,19 @@ import ./private/int_comparison make_binary(`<`, bool) make_binary(`<=`, bool) make_binary(`==`, bool) -func isZero*(x: Stint): bool {.inline.} = isZero x.data -func isNegative*(x: Stint): bool {.inline.} = isNegative x.data +make_unary(isZero, bool) +make_unary(isNegative, bool) + +func isOdd(x: SomeSignedInt): bool {.inline.}= + # internal + bool(x and 1) + +func isEven(x: SomeSignedInt): bool {.inline.}= + # internal + not x.isOdd + +make_unary(isOdd, bool) +make_unary(isEven, bool) import ./private/int_bitwise_ops diff --git a/stint/modular_arithmetic.nim b/stint/modular_arithmetic.nim new file mode 100644 index 0000000..773904d --- /dev/null +++ b/stint/modular_arithmetic.nim @@ -0,0 +1,119 @@ +# Stint +# Copyright 2018 Status Research & Development GmbH +# Licensed under either of +# +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +# +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import ./uint_public + +func addmod_internal(a, b, m: Stuint): Stuint {.inline.}= + ## Modular addition + ## ⚠⚠ Assume a < m and b < m + + assert a < m + assert b < m + + # We don't do a_m + b_m directly to avoid overflows + let b_from_m = m - b + + if a >= b_from_m: + return a - b_from_m + return m - b_from_m + a + +func submod_internal(a, b, m: Stuint): Stuint {.inline.}= + ## Modular substraction + ## ⚠⚠ Assume a < m and b < m + + assert a < m + assert b < m + + # We don't do a_m - b_m directly to avoid underflows + if a >= b: + return a - b + return m - b + a + + +func doublemod_internal(a, m: Stuint): Stuint {.inline.}= + ## Double a modulo m. Assume a < m + ## Internal proc - used in mulmod + + assert a < m + + result = a + if a >= m - a: + result -= m + result += a + +func mulmod_internal(a, b, m: Stuint): Stuint {.inline.}= + ## Does (a * b) mod m. Assume a < m and b < m + ## Internal proc - used in powmod + + assert a < m + assert b < m + + var (a, b) = (a, b) + + if b > a: + swap(a, b) + + while not b.isZero: + if b.isOdd: + result = result.addmod_internal(a, m) + a = doublemod_internal(a, m) + b = b shr 1 + +func powmod_internal(a, b, m: Stuint): Stuint {.inline.}= + ## Compute ``(a ^ b) mod m``, assume a < m + ## Internal proc + + assert a < m + + var (a, b) = (a, b) + result = one(type a) + + while not b.isZero: + if b.isOdd: + result = result.mulmod_internal(a, m) + b = b shr 1 + a = mulmod_internal(a, a, m) + +func addmod*(a, b, m: Stuint): Stuint = + ## Modular addition + + let a_m = if a < m: a + else: a mod m + let b_m = if b < m: b + else: b mod m + + result = addmod_internal(a_m, b_m, m) + +proc submod*(a, b, m: Stuint): Stuint = + ## Modular substraction + + let a_m = if a < m: a + else: a mod m + let b_m = if b < m: b + else: b mod m + + result = submod_internal(a_m, b_m, m) + +func mulmod*(a, b, m: Stuint): Stuint = + ## Modular multiplication + + let a_m = if a < m: a + else: a mod m + let b_m = if b < m: b + else: b mod m + + result = mulmod_internal(a_m, b_m, m) + +proc powmod*(a, b, m: Stuint): Stuint = + ## Modular exponentiation + + let a_m = if a < m: a + else: a mod m + + result = powmod_internal(a_m, b, m) diff --git a/stint/private/int_comparison.nim b/stint/private/int_comparison.nim index 6c067d7..0f1bdb7 100644 --- a/stint/private/int_comparison.nim +++ b/stint/private/int_comparison.nim @@ -43,3 +43,9 @@ func `<=`*(x, y: IntImpl): bool {.inline.}= if x != y: return x < y return true # they're equal + +func isOdd*(x: IntImpl): bool {.inline.}= + bool(x.least_significant_word and 1) + +func isEven*(x: IntImpl): bool {.inline.}= + not x.isOdd diff --git a/stint/private/uint_comparison.nim b/stint/private/uint_comparison.nim index 0a2509a..6bb5d69 100644 --- a/stint/private/uint_comparison.nim +++ b/stint/private/uint_comparison.nim @@ -38,3 +38,9 @@ func `<=`*(x, y: UintImpl): bool {.inline.}= if x != y: return x < y return true # they're equal + +func isOdd*(x: UintImpl): bool {.inline.}= + bool(x.least_significant_word and 1) + +func isEven*(x: UintImpl): bool {.inline.}= + not x.isOdd diff --git a/stint/private/uint_exp.nim b/stint/private/uint_exp.nim index 01db358..1103ed7 100644 --- a/stint/private/uint_exp.nim +++ b/stint/private/uint_exp.nim @@ -23,7 +23,7 @@ func pow*(x: UintImpl, y: Natural): UintImpl = result = one(type x) while true: - if (y and 1) != 0: + if bool(y and 1): # if y is odd result = result * x y = y shr 1 if y == 0: @@ -42,7 +42,7 @@ func pow*(x: UintImpl, y: UintImpl): UintImpl = result = one(type x) while true: - if not (y and one(type y)).isZero: + if y.isOdd: result = result * x y = y shr 1 if y.isZero: diff --git a/stint/uint_public.nim b/stint/uint_public.nim index e1fa69b..2ff55b4 100644 --- a/stint/uint_public.nim +++ b/stint/uint_public.nim @@ -53,7 +53,18 @@ import ./private/uint_comparison make_binary(`<`, bool) make_binary(`<=`, bool) make_binary(`==`, bool) -func isZero*(x: StUint): bool {.inline.} = isZero x.data +make_unary(isZero, bool) + +func isOdd(x: SomeUnsignedInt): bool {.inline.}= + # internal + bool(x and 1) + +func isEven(x: SomeUnsignedInt): bool {.inline.}= + # internal + not x.isOdd + +make_unary(isOdd, bool) +make_unary(isEven, bool) import ./private/uint_bitwise_ops diff --git a/tests/test_int_comparison.nim b/tests/test_int_comparison.nim index 5ab3f20..ba49c36 100644 --- a/tests/test_int_comparison.nim +++ b/tests/test_int_comparison.nim @@ -54,3 +54,12 @@ suite "Signed int - Testing comparison operators": a >= -c b >= -c -b >= -b + + test "isOdd/isEven": + check: + a.isEven + not a.isOdd + b.isOdd + not b.isEven + c.isEven + not c.isOdd diff --git a/tests/test_uint_comparison.nim b/tests/test_uint_comparison.nim index 274da04..35b9958 100644 --- a/tests/test_uint_comparison.nim +++ b/tests/test_uint_comparison.nim @@ -53,3 +53,12 @@ suite "Testing unsigned int comparison operators": cast[StUint[16]](c) >= a * b d >= e f >= d + + test "isOdd/isEven": + check: + a.isEven + not a.isOdd + b.isOdd + not b.isEven + c.isEven + not c.isOdd diff --git a/tests/uint_modular_arithmetic.nim b/tests/uint_modular_arithmetic.nim new file mode 100644 index 0000000..1e137d0 --- /dev/null +++ b/tests/uint_modular_arithmetic.nim @@ -0,0 +1,46 @@ +# Stint +# Copyright 2018 Status Research & Development GmbH +# Licensed under either of +# +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +# * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +# +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import ../stint, unittest, math + +suite "Modular arithmetic": + test "Modular addition": + + # uint16 rolls over at 65535 + let a = 50000.stuint(16) + let b = 20000.stuint(16) + let m = 60000.stuint(16) + + check: addmod(a, b, m) == 10000.stuint(16) + + test "Modular substraction": + + let a = 5.stuint(16) + let b = 7.stuint(16) + let m = 20.stuint(16) + + check: submod(a, b, m) == 18.stuint(16) + + test "Modular multiplication": + # https://www.wolframalpha.com/input/?i=(1234567890+*+987654321)+mod+999999999 + # --> 345_679_002 + let a = 1234567890.stuint(64) + let b = 987654321.stuint(64) + let m = 999999999.stuint(64) + + check: mulmod(a, b, m) == 345_679_002.stuint(64) + + test "Modular exponentiation": + # https://www.khanacademy.org/computing/computer-science/cryptography/modarithmetic/a/fast-modular-exponentiation + check: + powmod(5.stuint(16), 117.stuint(16), 19.stuint(16)) == 1.stuint(16) + powmod(3.stuint(16), 1993.stuint(16), 17.stuint(16)) == 14.stuint(16) + + check: + powmod(12.stuint(256), 34.stuint(256), high(UInt256)) == "4922235242952026704037113243122008064".u256