Accelerate eth_evm_modexp by 25x by dividing input size by 8 (#249)

* Accelerate eth_evm_modexp by 25x by dividing input size by 8 (scales quadratically)

* instant exponentiation by power of 2 depending on trailing zeroes

* improve bench report

* rename

* rewrite the pow2k even/trailingZero accel

* eth_evm_modexp: remove leftover TimeEffect
This commit is contained in:
Mamy Ratsimbazafy 2023-07-03 01:45:36 +02:00 committed by GitHub
parent d0f4ad8cda
commit b7687ddc4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 115 additions and 32 deletions

View File

@ -0,0 +1,69 @@
import
../constantine/ethereum_evm_precompiles,
./platforms, ./bench_blueprint,
../constantine/platforms/codecs
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
  ## Print one benchmark result line for `op`.
  ##
  ## `elapsedNs` and `elapsedCycles` are totals accumulated over `iters`
  ## runs; both are averaged with integer division before printing.
  ## Throughput (ops/s) is derived from the averaged, truncated ns/op value.
  ## The cycle column is only printed when `SupportsGetTicks` is true.
  let
    ns = elapsedNs div iters
    cycles = elapsedCycles div iters
    throughput = 1e9 / ns.float64
  # NOTE: the format strings below interpolate the locals `op`, `throughput`,
  # `ns` and `cycles` by name, so those identifiers must not be renamed.
  when SupportsGetTicks:
    echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
  else:
    echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op"
template bench(fnCall: untyped, ticks, ns: var int64): untyped =
  ## Execute `fnCall` once and add its cost to the running totals:
  ## `ticks` receives the `getTicks()` delta and `ns` receives the
  ## monotonic-clock delta in nanoseconds.
  ##
  ## The tick counter is sampled innermost (immediately around the call),
  ## the monotonic clock outermost.
  block:
    let wallBegin = getMonotime()
    let tickBegin = getTicks()
    fnCall
    let tickEnd = getTicks()
    let wallEnd = getMonotime()
    # Accumulate rather than assign so callers can sum over many iterations.
    ticks += tickEnd - tickBegin
    ns += (wallEnd - wallBegin).inNanoseconds
proc main() =
  ## Benchmark `eth_evm_modexp` on a fixed input and print the averaged
  ## result followed by the total wall time in milliseconds.
  ##
  ## Input layout: three 32-byte big-endian length fields (base, exponent,
  ## modulus - each equal to 32) followed by the base, exponent and modulus
  ## payloads. The modulus is 0x80 followed by 31 zero bytes.
  let input = [
    # Length of base (32)
    0x00'u8,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,

    # Length of exponent (32)
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,

    # Length of modulus (32)
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,

    # Base (96064778440517843452771003943013638877275214272712651271554889917016327417616)
    0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34,
    0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10,

    # Exponent (96064778440517843452771003943013638877275214272712651271554889917016327417616)
    0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34,
    0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10,

    # Modulus (57896044618658097711785492504343953926634992332820282019728792003956564819968)
    0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
  ]

  # Output buffer sized to the modulus byte length declared in the input (32).
  var output = newSeq[byte](32)
  var
    totalTicks: int64
    totalNs: int64

  const Iters = 22058
  for _ in 0 ..< Iters:
    # The returned status is deliberately bound to `_`; only timing matters.
    bench(
      (let _ = eth_evm_modexp(output, input)),
      totalTicks, totalNs)

  report("EVM Modexp", totalNs, totalTicks, Iters)
  echo "Total time: ", float64(totalNs) / 1e6, " ms"

main()

View File

@ -47,7 +47,7 @@ from bigints import nil # force qualified import to avoid conflicts on BigInt
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) = proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
let ns = elapsedNs div iters let ns = elapsedNs div iters
let cycles = elapsedCycles div iters let cycles = elapsedCycles div iters
let throughput = 1e9 / float64(elapsedNs) let throughput = 1e9 / float64(ns)
when SupportsGetTicks: when SupportsGetTicks:
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)" echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
else: else:

View File

@ -423,49 +423,53 @@ func eth_evm_modexp*(r: var openArray[byte], inputs: openArray[byte]): CttEVMSta
return cttEVM_InvalidInputSize return cttEVM_InvalidInputSize
let let
baseLen = cast[int](bL.limbs[0]) baseByteLen = cast[int](bL.limbs[0])
exponentLen = cast[int](eL.limbs[0]) exponentByteLen = cast[int](eL.limbs[0])
modulusLen = cast[int](mL.limbs[0]) modulusByteLen = cast[int](mL.limbs[0])
if r.len != modulusLen: baseWordLen = baseByteLen.ceilDiv_vartime(WordBitWidth div 8)
modulusWordLen = modulusByteLen.ceilDiv_vartime(WordBitWidth div 8)
if r.len != modulusByteLen:
return cttEVM_InvalidOutputSize return cttEVM_InvalidOutputSize
if baseLen.ceilDiv_vartime(WordBitWidth div 8) > modulusLen.ceilDiv_vartime(WordBitWidth div 8): if baseWordLen > modulusWordLen:
return cttEVM_InvalidInputSize return cttEVM_InvalidInputSize
# Special cases # Special cases
# ---------------------- # ----------------------
if modulusLen == 0: if modulusByteLen == 0:
r.setZero() r.setZero()
return cttEVM_Success return cttEVM_Success
if exponentLen == 0: if exponentByteLen == 0:
r.setZero() r.setZero()
r[r.len-1] = byte 1 # 0^0 = 1 and x^0 = 1 r[r.len-1] = byte 1 # 0^0 = 1 and x^0 = 1
return cttEVM_Success return cttEVM_Success
if baseLen == 0: if baseByteLen == 0:
r.setZero() r.setZero()
return cttEVM_Success return cttEVM_Success
# Input deserialization # Input deserialization
# --------------------- # ---------------------
var baseBuf = allocStackArray(SecretWord, baseLen)
var modulusBuf = allocStackArray(SecretWord, modulusLen)
var outputBuf = allocStackArray(SecretWord, modulusLen)
template base(): untyped = baseBuf.toOpenArray(0, baseLen-1)
template modulus(): untyped = modulusBuf.toOpenArray(0, modulusLen-1)
template output(): untyped = outputBuf.toOpenArray(0, modulusLen-1)
# Inclusive stops # Inclusive stops
let baseStart = 96 let baseStart = 96
let baseStop = baseStart+baseLen-1 let baseStop = baseStart+baseByteLen-1
let expStart = baseStop+1 let expStart = baseStop+1
let expStop = expStart+exponentLen-1 let expStop = expStart+exponentByteLen-1
let modStart = expStop+1 let modStart = expStop+1
let modStop = modStart+modulusLen-1 let modStop = modStart+modulusByteLen-1
base.toOpenArray(0, baseLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian) var baseBuf = allocStackArray(SecretWord, baseWordLen)
modulus.toOpenArray(0, modulusLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian) var modulusBuf = allocStackArray(SecretWord, modulusWordLen)
var outputBuf = allocStackArray(SecretWord, modulusWordLen)
template base(): untyped = baseBuf.toOpenArray(0, baseWordLen-1)
template modulus(): untyped = modulusBuf.toOpenArray(0, modulusWordLen-1)
template output(): untyped = outputBuf.toOpenArray(0, modulusWordLen-1)
base.toOpenArray(0, baseWordLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian)
modulus.toOpenArray(0, modulusWordLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian)
template exponent(): untyped = template exponent(): untyped =
inputs.toOpenArray(expStart, expStop) inputs.toOpenArray(expStart, expStop)

View File

@ -131,15 +131,16 @@ func powMod_vartime*(
# Even modulus # Even modulus
# ------------------------------------------------------------------- # -------------------------------------------------------------------
var i = 0 let ctz = block:
var i = 0
# Find the first non-zero word from right-to-left. (a != 0) # Find the first non-zero word from right-to-left. (a != 0)
while i < M.len-1: while i < M.len-1:
if bool(M[i] != Zero): if bool(M[i] != Zero):
break break
i += 1 i += 1
let ctz = int(countTrailingZeroBits_vartime(BaseType M[i])) + int(countTrailingZeroBits_vartime(BaseType M[i])) +
WordBitWidth*i WordBitWidth*i
# Even modulus: power of two (mod 2ᵏ) # Even modulus: power of two (mod 2ᵏ)

View File

@ -112,9 +112,18 @@ func powMod2k_vartime*(
r[0] = One # x⁰ = 1, even for 0⁰ r[0] = One # x⁰ = 1, even for 0⁰
return return
if a.isEven().bool and # if a is even if a.isEven().bool:
1+msb >= k.int: # The msb of a n-bit integer is at n-1 let aTrailingZeroes = block:
return # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ) var i = 0
while i < a.len-1:
if bool(a[i] != Zero):
break
i += 1
int(countTrailingZeroBits_vartime(BaseType a[i])) +
WordBitWidth*i
# if a is even, a = 2b and if e > k then there exists n such that e = k+n
if aTrailingZeroes+msb >= k.int: # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ)
return # we can generalize to a = 2ᵗᶻb with tz the number of trailing zeros.
var bitsLeft = msb+1 var bitsLeft = msb+1
if a.isOdd().bool and # if a is odd if a.isOdd().bool and # if a is odd
@ -130,7 +139,7 @@ func powMod2k_vartime*(
# range [r.len, a.len) will be truncated (mod 2ᵏ) # range [r.len, a.len) will be truncated (mod 2ᵏ)
sBuf[i] = a[i] sBuf[i] = a[i]
# TODO: sliding window # TODO: sliding/fixed window exponentiation
for i in countdown(exponent.len-1, 0): for i in countdown(exponent.len-1, 0):
for bit in unpackLE(exponent[i]): for bit in unpackLE(exponent[i]):
if bit: if bit: