From 7c81df9adc80088f46a4c2b8bf2a46c26fab057c Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Tue, 9 Jul 2024 15:35:08 +0200 Subject: [PATCH] fix modmul 256-bit perf (#156) Not sure about other lengths, but this 100x's 256-bit `modmul` on the given trivial benchmark and fixes abysmally slow EVM performance ``` Modmul (stint): 856300 ms ``` ``` Modmul (stint): 8850 ms ``` --- benchmarks/bench.nim | 31 +++++++++----- stint/modular_arithmetic.nim | 80 ++++++++++-------------------------- 2 files changed, 42 insertions(+), 69 deletions(-) diff --git a/benchmarks/bench.nim b/benchmarks/bench.nim index 36d9a55..9d44ff4 100644 --- a/benchmarks/bench.nim +++ b/benchmarks/bench.nim @@ -23,8 +23,8 @@ let a = [123'u64, 123'u64, 123'u64, 123'u64] let m = [456'u64, 456'u64, 456'u64, 45'u64] proc add_stint(a, m: array[4, uint64]) = - let aU256 = cast[Stuint[256]](a) - let mU256 = cast[Stuint[256]](m) + let aU256 = cast[StUint[256]](a) + let mU256 = cast[StUint[256]](m) bench "Add (stint)": var foo = aU256 @@ -33,8 +33,8 @@ proc add_stint(a, m: array[4, uint64]) = foo += aU256 proc mul_stint(a, m: array[4, uint64]) = - let aU256 = cast[Stuint[256]](a) - let mU256 = cast[Stuint[256]](m) + let aU256 = cast[StUint[256]](a) + let mU256 = cast[StUint[256]](m) bench "Mul (stint)": var foo = aU256 @@ -42,17 +42,28 @@ proc mul_stint(a, m: array[4, uint64]) = foo += (foo * foo) proc mod_stint(a, m: array[4, uint64]) = - let aU256 = cast[Stuint[256]](a) - let mU256 = cast[Stuint[256]](m) + let aU256 = cast[StUint[256]](a) + let mU256 = cast[StUint[256]](m) bench "Mod (stint)": var foo = aU256 for i in 0 ..< 100_000_000: foo += (foo * foo) mod mU256 -add_stint(a, m) -mul_stint(a, m) -mod_stint(a, m) +proc mulmod_stint(a, m: array[4, uint64]) = + let aU256 = cast[StUint[256]](a) + let mU256 = cast[StUint[256]](m) + + bench "Modmul (stint)": + var foo = aU256 + for i in 0 ..< 100_000_000: + foo += mulmod(aU256, aU256, mU256) + +# add_stint(a, m) +# mul_stint(a, m) +# mod_stint(a, m) + +mulmod_stint(a, m) when defined(bench_ttmath): # need C++ @@ -88,4 +99,4 @@ when defined(bench_ttmath): add_ttmath(a, m) mul_ttmath(a, m) - mod_ttmath(a, m) \ No newline at end of file + mod_ttmath(a, m) diff --git a/stint/modular_arithmetic.nim b/stint/modular_arithmetic.nim index 744b529..ebb72bc 100644 --- a/stint/modular_arithmetic.nim +++ b/stint/modular_arithmetic.nim @@ -22,8 +22,9 @@ func addmod_internal(a, b, m: StUint): StUint {.inline.}= let b_from_m = m - b if a >= b_from_m: - return a - b_from_m - return m - b_from_m + a + a - b_from_m + else: + m - b_from_m + a func submod_internal(a, b, m: StUint): StUint {.inline.}= ## Modular substraction @@ -34,53 +35,9 @@ func submod_internal(a, b, m: StUint): StUint {.inline.}= # We don't do a_m - b_m directly to avoid underflows if a >= b: - return a - b - return m - b + a - - -func doublemod_internal(a, m: StUint): StUint {.inline.}= - ## Double a modulo m. Assume a < m - ## Internal proc - used in mulmod - - doAssert a < m - - result = a - if a >= m - a: - result -= m - result += a - -func mulmod_internal(a, b, m: StUint): StUint {.inline.}= - ## Does (a * b) mod m. Assume a < m and b < m - ## Internal proc - used in powmod - - doAssert a < m - doAssert b < m - - var (a, b) = (a, b) - - if b > a: - swap(a, b) - - while not b.isZero: - if b.isOdd: - result = result.addmod_internal(a, m) - a = doublemod_internal(a, m) - b = b shr 1 - -func powmod_internal(a, b, m: StUint): StUint {.inline.}= - ## Compute ``(a ^ b) mod m``, assume a < m - ## Internal proc - - doAssert a < m - - var (a, b) = (a, b) - result = one(type a) - - while not b.isZero: - if b.isOdd: - result = result.mulmod_internal(a, m) - b = b shr 1 - a = mulmod_internal(a, a, m) + a - b + else: + m - b + a func addmod*(a, b, m: StUint): StUint = ## Modular addition @@ -90,7 +47,7 @@ func addmod*(a, b, m: StUint): StUint = let b_m = if b < m: b else: b mod m - result = addmod_internal(a_m, b_m, m) + addmod_internal(a_m, b_m, m) func submod*(a, b, m: StUint): StUint = ## Modular substraction @@ -100,24 +57,29 @@ func submod*(a, b, m: StUint): StUint = let b_m = if b < m: b else: b mod m - result = submod_internal(a_m, b_m, m) + submod_internal(a_m, b_m, m) func mulmod*(a, b, m: StUint): StUint = ## Modular multiplication - let a_m = if a < m: a - else: a mod m - let b_m = if b < m: b - else: b mod m + let + ax = a.stuint(a.bits * 2) + bx = b.stuint(b.bits * 2) + mx = m.stuint(m.bits * 2) + px = ax * bx - result = mulmod_internal(a_m, b_m, m) + divmod(px, mx).rem.stuint(a.bits) func powmod*(a, b, m: StUint): StUint = ## Modular exponentiation - let a_m = if a < m: a - else: a mod m + var (a, b) = (a, b) + result = one(type a) - result = powmod_internal(a_m, b, m) + while not b.isZero: + if b.isOdd: + result = result.mulmod(a, m) + b = b shr 1 + a = mulmod(a, a, m) {.pop.}