Accelerate eth_evm_modexp by 25x by dividing input size by 8 (#249)

* Accelerate eth_evm_modexp by 25x by dividing input size by 8 (scales quadratically)

* instant exponentiation by power of 2 depending on trailing zeroes

* improve bench report

* rename

* rewrite the pow2k even/trailingZero accel

* eth_evm_modexp: remove leftover TimeEffect
This commit is contained in:
Mamy Ratsimbazafy 2023-07-03 01:45:36 +02:00 committed by GitHub
parent d0f4ad8cda
commit b7687ddc4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 115 additions and 32 deletions

View File

@ -0,0 +1,69 @@
import
../constantine/ethereum_evm_precompiles,
./platforms, ./bench_blueprint,
../constantine/platforms/codecs
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
let ns = elapsedNs div iters
let cycles = elapsedCycles div iters
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
else:
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op"
template bench(fnCall: untyped, ticks, ns: var int64): untyped =
block:
let startTime = getMonotime()
let startClock = getTicks()
fnCall
let stopClock = getTicks()
let stopTime = getMonotime()
ticks += stopClock - startClock
ns += inNanoseconds(stopTime-startTime)
proc main() =
let input = [
# Length of base (32)
(uint8)0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
# Length of exponent (32)
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
# Length of modulus (32)
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
# Base (96064778440517843452771003943013638877275214272712651271554889917016327417616)
0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34,
0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10,
# Exponent (96064778440517843452771003943013638877275214272712651271554889917016327417616)
0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34,
0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10,
# Modulus (57896044618658097711785492504343953926634992332820282019728792003956564819968)
0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
]
var r = newSeq[byte](32)
var ticks, nanoseconds: int64
const Iters = 22058
for i in 0 ..< Iters:
bench(
(let _ = r.eth_evm_modexp(input)),
ticks, nanoseconds)
report("EVM Modexp", nanoseconds, ticks, Iters)
echo "Total time: ", nanoseconds.float64 / 1e6, " ms"
main()

View File

@ -47,7 +47,7 @@ from bigints import nil # force qualified import to avoid conflicts on BigInt
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
let ns = elapsedNs div iters
let cycles = elapsedCycles div iters
let throughput = 1e9 / float64(elapsedNs)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
else:

View File

@ -423,49 +423,53 @@ func eth_evm_modexp*(r: var openArray[byte], inputs: openArray[byte]): CttEVMSta
return cttEVM_InvalidInputSize
let
baseLen = cast[int](bL.limbs[0])
exponentLen = cast[int](eL.limbs[0])
modulusLen = cast[int](mL.limbs[0])
baseByteLen = cast[int](bL.limbs[0])
exponentByteLen = cast[int](eL.limbs[0])
modulusByteLen = cast[int](mL.limbs[0])
if r.len != modulusLen:
baseWordLen = baseByteLen.ceilDiv_vartime(WordBitWidth div 8)
modulusWordLen = modulusByteLen.ceilDiv_vartime(WordBitWidth div 8)
if r.len != modulusByteLen:
return cttEVM_InvalidOutputSize
if baseLen.ceilDiv_vartime(WordBitWidth div 8) > modulusLen.ceilDiv_vartime(WordBitWidth div 8):
if baseWordLen > modulusWordLen:
return cttEVM_InvalidInputSize
# Special cases
# ----------------------
if modulusLen == 0:
if modulusByteLen == 0:
r.setZero()
return cttEVM_Success
if exponentLen == 0:
if exponentByteLen == 0:
r.setZero()
r[r.len-1] = byte 1 # 0^0 = 1 and x^0 = 1
return cttEVM_Success
if baseLen == 0:
if baseByteLen == 0:
r.setZero()
return cttEVM_Success
# Input deserialization
# ---------------------
var baseBuf = allocStackArray(SecretWord, baseLen)
var modulusBuf = allocStackArray(SecretWord, modulusLen)
var outputBuf = allocStackArray(SecretWord, modulusLen)
template base(): untyped = baseBuf.toOpenArray(0, baseLen-1)
template modulus(): untyped = modulusBuf.toOpenArray(0, modulusLen-1)
template output(): untyped = outputBuf.toOpenArray(0, modulusLen-1)
# Inclusive stops
let baseStart = 96
let baseStop = baseStart+baseLen-1
let baseStop = baseStart+baseByteLen-1
let expStart = baseStop+1
let expStop = expStart+exponentLen-1
let expStop = expStart+exponentByteLen-1
let modStart = expStop+1
let modStop = modStart+modulusLen-1
let modStop = modStart+modulusByteLen-1
base.toOpenArray(0, baseLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian)
modulus.toOpenArray(0, modulusLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian)
var baseBuf = allocStackArray(SecretWord, baseWordLen)
var modulusBuf = allocStackArray(SecretWord, modulusWordLen)
var outputBuf = allocStackArray(SecretWord, modulusWordLen)
template base(): untyped = baseBuf.toOpenArray(0, baseWordLen-1)
template modulus(): untyped = modulusBuf.toOpenArray(0, modulusWordLen-1)
template output(): untyped = outputBuf.toOpenArray(0, modulusWordLen-1)
base.toOpenArray(0, baseWordLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian)
modulus.toOpenArray(0, modulusWordLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian)
template exponent(): untyped =
inputs.toOpenArray(expStart, expStop)

View File

@ -131,15 +131,16 @@ func powMod_vartime*(
# Even modulus
# -------------------------------------------------------------------
var i = 0
let ctz = block:
var i = 0
# Find the first non-zero word from right-to-left. (a != 0)
while i < M.len-1:
if bool(M[i] != Zero):
break
i += 1
# Find the first non-zero word from right-to-left. (a != 0)
while i < M.len-1:
if bool(M[i] != Zero):
break
i += 1
let ctz = int(countTrailingZeroBits_vartime(BaseType M[i])) +
int(countTrailingZeroBits_vartime(BaseType M[i])) +
WordBitWidth*i
# Even modulus: power of two (mod 2ᵏ)

View File

@ -112,9 +112,18 @@ func powMod2k_vartime*(
r[0] = One # x⁰ = 1, even for 0⁰
return
if a.isEven().bool and # if a is even
1+msb >= k.int: # The msb of a n-bit integer is at n-1
return # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ)
if a.isEven().bool:
let aTrailingZeroes = block:
var i = 0
while i < a.len-1:
if bool(a[i] != Zero):
break
i += 1
int(countTrailingZeroBits_vartime(BaseType a[i])) +
WordBitWidth*i
# if a is even, a = 2b and if e > k then there exists n such that e = k+n
if aTrailingZeroes+msb >= k.int: # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ)
return # we can generalize to a = 2ᵗᶻb with tz the number of trailing zeros.
var bitsLeft = msb+1
if a.isOdd().bool and # if a is odd
@ -130,7 +139,7 @@ func powMod2k_vartime*(
# range [r.len, a.len) will be truncated (mod 2ᵏ)
sBuf[i] = a[i]
# TODO: sliding window
# TODO: sliding/fixed window exponentiation
for i in countdown(exponent.len-1, 0):
for bit in unpackLE(exponent[i]):
if bit: