From b7687ddc4a5dde8ce01f597f7d4e4396b34f38fc Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Mon, 3 Jul 2023 01:45:36 +0200 Subject: [PATCH] Accelerate eth_evm_modexp by 25x by dividing input size by 8 (#249) * Accelerate eth_evm_modexp by 25x by dividing input size by 8 (scales quadratically) * instant exponentiation by power of 2 depending on trailing zeroes * improve bench report * rename * rewrite the pow2k even/trailingZero accel * eth_evm_modexp: remove leftover TimeEffect --- benchmarks/bench_evm_modexp_dos.nim | 69 +++++++++++++++++++ benchmarks/bench_powmod.nim | 2 +- constantine/ethereum_evm_precompiles.nim | 44 ++++++------ .../arithmetic/bigints_views.nim | 15 ++-- .../arithmetic/limbs_mod2k.nim | 17 +++-- 5 files changed, 115 insertions(+), 32 deletions(-) create mode 100644 benchmarks/bench_evm_modexp_dos.nim diff --git a/benchmarks/bench_evm_modexp_dos.nim b/benchmarks/bench_evm_modexp_dos.nim new file mode 100644 index 0000000..1fc8c37 --- /dev/null +++ b/benchmarks/bench_evm_modexp_dos.nim @@ -0,0 +1,69 @@ +import + ../constantine/ethereum_evm_precompiles, + ./platforms, ./bench_blueprint, + + ../constantine/platforms/codecs + +proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) = + let ns = elapsedNs div iters + let cycles = elapsedCycles div iters + let throughput = 1e9 / float64(ns) + when SupportsGetTicks: + echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)" + else: + echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op" + +template bench(fnCall: untyped, ticks, ns: var int64): untyped = + block: + let startTime = getMonotime() + let startClock = getTicks() + fnCall + let stopClock = getTicks() + let stopTime = getMonotime() + + ticks += stopClock - startClock + ns += inNanoseconds(stopTime-startTime) + +proc main() = + + let input = [ + # Length of base (32) + (uint8)0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, + + # Length of exponent (32) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, + + # Length of modulus (32) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, + + # Base (96064778440517843452771003943013638877275214272712651271554889917016327417616) + 0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34, + 0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10, + + # Exponent (96064778440517843452771003943013638877275214272712651271554889917016327417616) + 0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34, + 0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10, + + # Modulus (57896044618658097711785492504343953926634992332820282019728792003956564819968) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ] + + var r = newSeq[byte](32) + var ticks, nanoseconds: int64 + + const Iters = 22058 + + for i in 0 ..< Iters: + bench( + (let _ = r.eth_evm_modexp(input)), + ticks, nanoseconds) + + report("EVM Modexp", nanoseconds, ticks, Iters) + echo "Total time: ", nanoseconds.float64 / 1e6, " ms" + +main() \ No newline at end of file diff --git a/benchmarks/bench_powmod.nim b/benchmarks/bench_powmod.nim index 4f2daf4..64b6bdd 100644 --- a/benchmarks/bench_powmod.nim +++ b/benchmarks/bench_powmod.nim @@ -47,7 +47,7 @@ from bigints import nil # force qualified import to avoid conflicts on BigInt proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) = let ns = elapsedNs div iters let cycles = elapsedCycles div iters - let throughput = 1e9 / float64(elapsedNs) + let throughput = 1e9 / float64(ns) when SupportsGetTicks: echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)" else: diff --git a/constantine/ethereum_evm_precompiles.nim b/constantine/ethereum_evm_precompiles.nim index 5534ffe..193e7e9 100644 --- a/constantine/ethereum_evm_precompiles.nim +++ b/constantine/ethereum_evm_precompiles.nim @@ -423,49 +423,53 @@ func eth_evm_modexp*(r: var openArray[byte], inputs: openArray[byte]): CttEVMSta return cttEVM_InvalidInputSize let - baseLen = cast[int](bL.limbs[0]) - exponentLen = cast[int](eL.limbs[0]) - modulusLen = cast[int](mL.limbs[0]) + baseByteLen = cast[int](bL.limbs[0]) + exponentByteLen = cast[int](eL.limbs[0]) + modulusByteLen = cast[int](mL.limbs[0]) - if r.len != modulusLen: + baseWordLen = baseByteLen.ceilDiv_vartime(WordBitWidth div 8) + modulusWordLen = modulusByteLen.ceilDiv_vartime(WordBitWidth div 8) + + if r.len != modulusByteLen: return cttEVM_InvalidOutputSize - if baseLen.ceilDiv_vartime(WordBitWidth div 8) > modulusLen.ceilDiv_vartime(WordBitWidth div 8): + if baseWordLen > modulusWordLen: return cttEVM_InvalidInputSize # Special cases # ---------------------- - if modulusLen == 0: + if modulusByteLen == 0: r.setZero() return cttEVM_Success - if exponentLen == 0: + if exponentByteLen == 0: r.setZero() r[r.len-1] = byte 1 # 0^0 = 1 and x^0 = 1 return cttEVM_Success - if baseLen == 0: + if baseByteLen == 0: r.setZero() return cttEVM_Success # Input deserialization # --------------------- - var baseBuf = allocStackArray(SecretWord, baseLen) - var modulusBuf = allocStackArray(SecretWord, modulusLen) - var outputBuf = allocStackArray(SecretWord, modulusLen) - - template base(): untyped = baseBuf.toOpenArray(0, baseLen-1) - template modulus(): untyped = modulusBuf.toOpenArray(0, modulusLen-1) - template output(): untyped = outputBuf.toOpenArray(0, modulusLen-1) # Inclusive stops let baseStart = 96 - let baseStop = baseStart+baseLen-1 + let baseStop = baseStart+baseByteLen-1 let expStart = baseStop+1 - let expStop = expStart+exponentLen-1 + let expStop = expStart+exponentByteLen-1 let modStart = expStop+1 - let modStop = modStart+modulusLen-1 + let modStop = modStart+modulusByteLen-1 - base.toOpenArray(0, baseLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian) - modulus.toOpenArray(0, modulusLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian) + var baseBuf = allocStackArray(SecretWord, baseWordLen) + var modulusBuf = allocStackArray(SecretWord, modulusWordLen) + var outputBuf = allocStackArray(SecretWord, modulusWordLen) + + template base(): untyped = baseBuf.toOpenArray(0, baseWordLen-1) + template modulus(): untyped = modulusBuf.toOpenArray(0, modulusWordLen-1) + template output(): untyped = outputBuf.toOpenArray(0, modulusWordLen-1) + + base.toOpenArray(0, baseWordLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian) + modulus.toOpenArray(0, modulusWordLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian) template exponent(): untyped = inputs.toOpenArray(expStart, expStop) diff --git a/constantine/math_arbitrary_precision/arithmetic/bigints_views.nim b/constantine/math_arbitrary_precision/arithmetic/bigints_views.nim index 0c953f4..7bbd79e 100644 --- a/constantine/math_arbitrary_precision/arithmetic/bigints_views.nim +++ b/constantine/math_arbitrary_precision/arithmetic/bigints_views.nim @@ -131,15 +131,16 @@ func powMod_vartime*( # Even modulus # ------------------------------------------------------------------- - var i = 0 + let ctz = block: + var i = 0 - # Find the first non-zero word from right-to-left. (a != 0) - while i < M.len-1: - if bool(M[i] != Zero): - break - i += 1 + # Find the first non-zero word from right-to-left. (a != 0) + while i < M.len-1: + if bool(M[i] != Zero): + break + i += 1 - let ctz = int(countTrailingZeroBits_vartime(BaseType M[i])) + + int(countTrailingZeroBits_vartime(BaseType M[i])) + WordBitWidth*i # Even modulus: power of two (mod 2ᵏ) diff --git a/constantine/math_arbitrary_precision/arithmetic/limbs_mod2k.nim b/constantine/math_arbitrary_precision/arithmetic/limbs_mod2k.nim index 303275a..c9f2fc9 100644 --- a/constantine/math_arbitrary_precision/arithmetic/limbs_mod2k.nim +++ b/constantine/math_arbitrary_precision/arithmetic/limbs_mod2k.nim @@ -112,9 +112,18 @@ func powMod2k_vartime*( r[0] = One # x⁰ = 1, even for 0⁰ return - if a.isEven().bool and # if a is even - 1+msb >= k.int: # The msb of a n-bit integer is at n-1 - return # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ) + if a.isEven().bool: + let aTrailingZeroes = block: + var i = 0 + while i < a.len-1: + if bool(a[i] != Zero): + break + i += 1 + int(countTrailingZeroBits_vartime(BaseType a[i])) + + WordBitWidth*i + # if a is even, a = 2b and if e > k then there exists n such that e = k+n + if aTrailingZeroes+msb >= k.int: # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ) + return # we can generalize to a = 2ᵗᶻb with tz the number of trailing zeros. var bitsLeft = msb+1 if a.isOdd().bool and # if a is odd @@ -130,7 +139,7 @@ func powMod2k_vartime*( # range [r.len, a.len) will be truncated (mod 2ᵏ) sBuf[i] = a[i] - # TODO: sliding window + # TODO: sliding/fixed window exponentiation for i in countdown(exponent.len-1, 0): for bit in unpackLE(exponent[i]): if bit: