fix modmul 256-bit perf (#156)

Not sure about other lengths, but this 100x's 256-bit `modmul` on the
given trivial benchmark and fixes abysmally slow EVM performance

```
Modmul (stint): 856300 ms
```

```
Modmul (stint): 8850 ms
```
This commit is contained in:
Jacek Sieka 2024-07-09 15:35:08 +02:00 committed by GitHub
parent 9a3348bd44
commit 7c81df9adc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 69 deletions

View File

@ -23,8 +23,8 @@ let a = [123'u64, 123'u64, 123'u64, 123'u64]
let m = [456'u64, 456'u64, 456'u64, 45'u64] let m = [456'u64, 456'u64, 456'u64, 45'u64]
proc add_stint(a, m: array[4, uint64]) = proc add_stint(a, m: array[4, uint64]) =
let aU256 = cast[Stuint[256]](a) let aU256 = cast[StUint[256]](a)
let mU256 = cast[Stuint[256]](m) let mU256 = cast[StUint[256]](m)
bench "Add (stint)": bench "Add (stint)":
var foo = aU256 var foo = aU256
@ -33,8 +33,8 @@ proc add_stint(a, m: array[4, uint64]) =
foo += aU256 foo += aU256
proc mul_stint(a, m: array[4, uint64]) = proc mul_stint(a, m: array[4, uint64]) =
let aU256 = cast[Stuint[256]](a) let aU256 = cast[StUint[256]](a)
let mU256 = cast[Stuint[256]](m) let mU256 = cast[StUint[256]](m)
bench "Mul (stint)": bench "Mul (stint)":
var foo = aU256 var foo = aU256
@ -42,17 +42,28 @@ proc mul_stint(a, m: array[4, uint64]) =
foo += (foo * foo) foo += (foo * foo)
proc mod_stint(a, m: array[4, uint64]) = proc mod_stint(a, m: array[4, uint64]) =
let aU256 = cast[Stuint[256]](a) let aU256 = cast[StUint[256]](a)
let mU256 = cast[Stuint[256]](m) let mU256 = cast[StUint[256]](m)
bench "Mod (stint)": bench "Mod (stint)":
var foo = aU256 var foo = aU256
for i in 0 ..< 100_000_000: for i in 0 ..< 100_000_000:
foo += (foo * foo) mod mU256 foo += (foo * foo) mod mU256
add_stint(a, m) proc mulmod_stint(a, m: array[4, uint64]) =
mul_stint(a, m) let aU256 = cast[StUint[256]](a)
mod_stint(a, m) let mU256 = cast[StUint[256]](m)
bench "Modmul (stint)":
var foo = aU256
for i in 0 ..< 100_000_000:
foo += mulmod(aU256, aU256, mU256)
# add_stint(a, m)
# mul_stint(a, m)
# mod_stint(a, m)
mulmod_stint(a, m)
when defined(bench_ttmath): when defined(bench_ttmath):
# need C++ # need C++

View File

@ -22,8 +22,9 @@ func addmod_internal(a, b, m: StUint): StUint {.inline.}=
let b_from_m = m - b let b_from_m = m - b
if a >= b_from_m: if a >= b_from_m:
return a - b_from_m a - b_from_m
return m - b_from_m + a else:
m - b_from_m + a
func submod_internal(a, b, m: StUint): StUint {.inline.}= func submod_internal(a, b, m: StUint): StUint {.inline.}=
## Modular substraction ## Modular substraction
@ -34,53 +35,9 @@ func submod_internal(a, b, m: StUint): StUint {.inline.}=
# We don't do a_m - b_m directly to avoid underflows # We don't do a_m - b_m directly to avoid underflows
if a >= b: if a >= b:
return a - b a - b
return m - b + a else:
m - b + a
func doublemod_internal(a, m: StUint): StUint {.inline.}=
## Double a modulo m. Assume a < m
## Internal proc - used in mulmod
doAssert a < m
result = a
if a >= m - a:
result -= m
result += a
func mulmod_internal(a, b, m: StUint): StUint {.inline.}=
## Does (a * b) mod m. Assume a < m and b < m
## Internal proc - used in powmod
doAssert a < m
doAssert b < m
var (a, b) = (a, b)
if b > a:
swap(a, b)
while not b.isZero:
if b.isOdd:
result = result.addmod_internal(a, m)
a = doublemod_internal(a, m)
b = b shr 1
func powmod_internal(a, b, m: StUint): StUint {.inline.}=
## Compute ``(a ^ b) mod m``, assume a < m
## Internal proc
doAssert a < m
var (a, b) = (a, b)
result = one(type a)
while not b.isZero:
if b.isOdd:
result = result.mulmod_internal(a, m)
b = b shr 1
a = mulmod_internal(a, a, m)
func addmod*(a, b, m: StUint): StUint = func addmod*(a, b, m: StUint): StUint =
## Modular addition ## Modular addition
@ -90,7 +47,7 @@ func addmod*(a, b, m: StUint): StUint =
let b_m = if b < m: b let b_m = if b < m: b
else: b mod m else: b mod m
result = addmod_internal(a_m, b_m, m) addmod_internal(a_m, b_m, m)
func submod*(a, b, m: StUint): StUint = func submod*(a, b, m: StUint): StUint =
## Modular substraction ## Modular substraction
@ -100,24 +57,29 @@ func submod*(a, b, m: StUint): StUint =
let b_m = if b < m: b let b_m = if b < m: b
else: b mod m else: b mod m
result = submod_internal(a_m, b_m, m) submod_internal(a_m, b_m, m)
func mulmod*(a, b, m: StUint): StUint = func mulmod*(a, b, m: StUint): StUint =
## Modular multiplication ## Modular multiplication
let a_m = if a < m: a let
else: a mod m ax = a.stuint(a.bits * 2)
let b_m = if b < m: b bx = b.stuint(b.bits * 2)
else: b mod m mx = m.stuint(m.bits * 2)
px = ax * bx
result = mulmod_internal(a_m, b_m, m) divmod(px, mx).rem.stuint(a.bits)
func powmod*(a, b, m: StUint): StUint = func powmod*(a, b, m: StUint): StUint =
## Modular exponentiation ## Modular exponentiation
let a_m = if a < m: a var (a, b) = (a, b)
else: a mod m result = one(type a)
result = powmod_internal(a_m, b, m) while not b.isZero:
if b.isOdd:
result = result.mulmod(a, m)
b = b shr 1
a = mulmod(a, a, m)
{.pop.} {.pop.}