mirror of
https://github.com/logos-storage/constantine.git
synced 2026-01-07 07:33:08 +00:00
Accelerate eth_evm_modexp by 25x by dividing input size by 8 (#249)
* Accelerate eth_evm_modexp by 25x by dividing input size by 8 (scales quadratically) * instant exponentiation by power of 2 depending on trailing zeroes * improve bench report * rename * rewrite the pow2k even/trailingZero accel * eth_evm_modexp: remove leftover TimeEffect
This commit is contained in:
parent
d0f4ad8cda
commit
b7687ddc4a
69
benchmarks/bench_evm_modexp_dos.nim
Normal file
69
benchmarks/bench_evm_modexp_dos.nim
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import
|
||||||
|
../constantine/ethereum_evm_precompiles,
|
||||||
|
./platforms, ./bench_blueprint,
|
||||||
|
|
||||||
|
../constantine/platforms/codecs
|
||||||
|
|
||||||
|
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
|
||||||
|
let ns = elapsedNs div iters
|
||||||
|
let cycles = elapsedCycles div iters
|
||||||
|
let throughput = 1e9 / float64(ns)
|
||||||
|
when SupportsGetTicks:
|
||||||
|
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
|
||||||
|
else:
|
||||||
|
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op"
|
||||||
|
|
||||||
|
template bench(fnCall: untyped, ticks, ns: var int64): untyped =
|
||||||
|
block:
|
||||||
|
let startTime = getMonotime()
|
||||||
|
let startClock = getTicks()
|
||||||
|
fnCall
|
||||||
|
let stopClock = getTicks()
|
||||||
|
let stopTime = getMonotime()
|
||||||
|
|
||||||
|
ticks += stopClock - startClock
|
||||||
|
ns += inNanoseconds(stopTime-startTime)
|
||||||
|
|
||||||
|
proc main() =
|
||||||
|
|
||||||
|
let input = [
|
||||||
|
# Length of base (32)
|
||||||
|
(uint8)0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
|
||||||
|
|
||||||
|
# Length of exponent (32)
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
|
||||||
|
|
||||||
|
# Length of modulus (32)
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
|
||||||
|
|
||||||
|
# Base (96064778440517843452771003943013638877275214272712651271554889917016327417616)
|
||||||
|
0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34,
|
||||||
|
0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10,
|
||||||
|
|
||||||
|
# Exponent (96064778440517843452771003943013638877275214272712651271554889917016327417616)
|
||||||
|
0xd4, 0x62, 0xbc, 0xde, 0x8f, 0x57, 0xb0, 0x4a, 0x3f, 0xe1, 0x16, 0xc8, 0x12, 0x8c, 0x44, 0x34,
|
||||||
|
0xcf, 0x10, 0x25, 0x2e, 0x48, 0xa3, 0xcc, 0x0d, 0x28, 0xdf, 0x2b, 0xac, 0x4a, 0x8d, 0x6f, 0x10,
|
||||||
|
|
||||||
|
# Modulus (57896044618658097711785492504343953926634992332820282019728792003956564819968)
|
||||||
|
0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
]
|
||||||
|
|
||||||
|
var r = newSeq[byte](32)
|
||||||
|
var ticks, nanoseconds: int64
|
||||||
|
|
||||||
|
const Iters = 22058
|
||||||
|
|
||||||
|
for i in 0 ..< Iters:
|
||||||
|
bench(
|
||||||
|
(let _ = r.eth_evm_modexp(input)),
|
||||||
|
ticks, nanoseconds)
|
||||||
|
|
||||||
|
report("EVM Modexp", nanoseconds, ticks, Iters)
|
||||||
|
echo "Total time: ", nanoseconds.float64 / 1e6, " ms"
|
||||||
|
|
||||||
|
main()
|
||||||
@ -47,7 +47,7 @@ from bigints import nil # force qualified import to avoid conflicts on BigInt
|
|||||||
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
|
proc report(op: string, elapsedNs: int64, elapsedCycles: int64, iters: int) =
|
||||||
let ns = elapsedNs div iters
|
let ns = elapsedNs div iters
|
||||||
let cycles = elapsedCycles div iters
|
let cycles = elapsedCycles div iters
|
||||||
let throughput = 1e9 / float64(elapsedNs)
|
let throughput = 1e9 / float64(ns)
|
||||||
when SupportsGetTicks:
|
when SupportsGetTicks:
|
||||||
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
|
echo &"{op:<45} {throughput:>15.3f} ops/s {ns:>16} ns/op {cycles:>12} CPU cycles (approx)"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -423,49 +423,53 @@ func eth_evm_modexp*(r: var openArray[byte], inputs: openArray[byte]): CttEVMSta
|
|||||||
return cttEVM_InvalidInputSize
|
return cttEVM_InvalidInputSize
|
||||||
|
|
||||||
let
|
let
|
||||||
baseLen = cast[int](bL.limbs[0])
|
baseByteLen = cast[int](bL.limbs[0])
|
||||||
exponentLen = cast[int](eL.limbs[0])
|
exponentByteLen = cast[int](eL.limbs[0])
|
||||||
modulusLen = cast[int](mL.limbs[0])
|
modulusByteLen = cast[int](mL.limbs[0])
|
||||||
|
|
||||||
if r.len != modulusLen:
|
baseWordLen = baseByteLen.ceilDiv_vartime(WordBitWidth div 8)
|
||||||
|
modulusWordLen = modulusByteLen.ceilDiv_vartime(WordBitWidth div 8)
|
||||||
|
|
||||||
|
if r.len != modulusByteLen:
|
||||||
return cttEVM_InvalidOutputSize
|
return cttEVM_InvalidOutputSize
|
||||||
|
|
||||||
if baseLen.ceilDiv_vartime(WordBitWidth div 8) > modulusLen.ceilDiv_vartime(WordBitWidth div 8):
|
if baseWordLen > modulusWordLen:
|
||||||
return cttEVM_InvalidInputSize
|
return cttEVM_InvalidInputSize
|
||||||
|
|
||||||
# Special cases
|
# Special cases
|
||||||
# ----------------------
|
# ----------------------
|
||||||
if modulusLen == 0:
|
if modulusByteLen == 0:
|
||||||
r.setZero()
|
r.setZero()
|
||||||
return cttEVM_Success
|
return cttEVM_Success
|
||||||
if exponentLen == 0:
|
if exponentByteLen == 0:
|
||||||
r.setZero()
|
r.setZero()
|
||||||
r[r.len-1] = byte 1 # 0^0 = 1 and x^0 = 1
|
r[r.len-1] = byte 1 # 0^0 = 1 and x^0 = 1
|
||||||
return cttEVM_Success
|
return cttEVM_Success
|
||||||
if baseLen == 0:
|
if baseByteLen == 0:
|
||||||
r.setZero()
|
r.setZero()
|
||||||
return cttEVM_Success
|
return cttEVM_Success
|
||||||
|
|
||||||
# Input deserialization
|
# Input deserialization
|
||||||
# ---------------------
|
# ---------------------
|
||||||
var baseBuf = allocStackArray(SecretWord, baseLen)
|
|
||||||
var modulusBuf = allocStackArray(SecretWord, modulusLen)
|
|
||||||
var outputBuf = allocStackArray(SecretWord, modulusLen)
|
|
||||||
|
|
||||||
template base(): untyped = baseBuf.toOpenArray(0, baseLen-1)
|
|
||||||
template modulus(): untyped = modulusBuf.toOpenArray(0, modulusLen-1)
|
|
||||||
template output(): untyped = outputBuf.toOpenArray(0, modulusLen-1)
|
|
||||||
|
|
||||||
# Inclusive stops
|
# Inclusive stops
|
||||||
let baseStart = 96
|
let baseStart = 96
|
||||||
let baseStop = baseStart+baseLen-1
|
let baseStop = baseStart+baseByteLen-1
|
||||||
let expStart = baseStop+1
|
let expStart = baseStop+1
|
||||||
let expStop = expStart+exponentLen-1
|
let expStop = expStart+exponentByteLen-1
|
||||||
let modStart = expStop+1
|
let modStart = expStop+1
|
||||||
let modStop = modStart+modulusLen-1
|
let modStop = modStart+modulusByteLen-1
|
||||||
|
|
||||||
base.toOpenArray(0, baseLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian)
|
var baseBuf = allocStackArray(SecretWord, baseWordLen)
|
||||||
modulus.toOpenArray(0, modulusLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian)
|
var modulusBuf = allocStackArray(SecretWord, modulusWordLen)
|
||||||
|
var outputBuf = allocStackArray(SecretWord, modulusWordLen)
|
||||||
|
|
||||||
|
template base(): untyped = baseBuf.toOpenArray(0, baseWordLen-1)
|
||||||
|
template modulus(): untyped = modulusBuf.toOpenArray(0, modulusWordLen-1)
|
||||||
|
template output(): untyped = outputBuf.toOpenArray(0, modulusWordLen-1)
|
||||||
|
|
||||||
|
base.toOpenArray(0, baseWordLen-1).unmarshal(inputs.toOpenArray(baseStart, baseStop), WordBitWidth, bigEndian)
|
||||||
|
modulus.toOpenArray(0, modulusWordLen-1).unmarshal(inputs.toOpenArray(modStart, modStop), WordBitWidth, bigEndian)
|
||||||
template exponent(): untyped =
|
template exponent(): untyped =
|
||||||
inputs.toOpenArray(expStart, expStop)
|
inputs.toOpenArray(expStart, expStop)
|
||||||
|
|
||||||
|
|||||||
@ -131,15 +131,16 @@ func powMod_vartime*(
|
|||||||
# Even modulus
|
# Even modulus
|
||||||
# -------------------------------------------------------------------
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
var i = 0
|
let ctz = block:
|
||||||
|
var i = 0
|
||||||
|
|
||||||
# Find the first non-zero word from right-to-left. (a != 0)
|
# Find the first non-zero word from right-to-left. (a != 0)
|
||||||
while i < M.len-1:
|
while i < M.len-1:
|
||||||
if bool(M[i] != Zero):
|
if bool(M[i] != Zero):
|
||||||
break
|
break
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
let ctz = int(countTrailingZeroBits_vartime(BaseType M[i])) +
|
int(countTrailingZeroBits_vartime(BaseType M[i])) +
|
||||||
WordBitWidth*i
|
WordBitWidth*i
|
||||||
|
|
||||||
# Even modulus: power of two (mod 2ᵏ)
|
# Even modulus: power of two (mod 2ᵏ)
|
||||||
|
|||||||
@ -112,9 +112,18 @@ func powMod2k_vartime*(
|
|||||||
r[0] = One # x⁰ = 1, even for 0⁰
|
r[0] = One # x⁰ = 1, even for 0⁰
|
||||||
return
|
return
|
||||||
|
|
||||||
if a.isEven().bool and # if a is even
|
if a.isEven().bool:
|
||||||
1+msb >= k.int: # The msb of a n-bit integer is at n-1
|
let aTrailingZeroes = block:
|
||||||
return # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ)
|
var i = 0
|
||||||
|
while i < a.len-1:
|
||||||
|
if bool(a[i] != Zero):
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
int(countTrailingZeroBits_vartime(BaseType a[i])) +
|
||||||
|
WordBitWidth*i
|
||||||
|
# if a is even, a = 2b and if e > k then there exists n such that e = k+n
|
||||||
|
if aTrailingZeroes+msb >= k.int: # r ≡ aᵉ (mod 2ᵏ) ≡ (2b)ᵏ⁺ⁿ (mod 2ᵏ) ≡ 2ᵏ.2ⁿ.bᵏ⁺ⁿ (mod 2ᵏ) ≡ 0 (mod 2ᵏ)
|
||||||
|
return # we can generalize to a = 2ᵗᶻb with tz the number of trailing zeros.
|
||||||
|
|
||||||
var bitsLeft = msb+1
|
var bitsLeft = msb+1
|
||||||
if a.isOdd().bool and # if a is odd
|
if a.isOdd().bool and # if a is odd
|
||||||
@ -130,7 +139,7 @@ func powMod2k_vartime*(
|
|||||||
# range [r.len, a.len) will be truncated (mod 2ᵏ)
|
# range [r.len, a.len) will be truncated (mod 2ᵏ)
|
||||||
sBuf[i] = a[i]
|
sBuf[i] = a[i]
|
||||||
|
|
||||||
# TODO: sliding window
|
# TODO: sliding/fixed window exponentiation
|
||||||
for i in countdown(exponent.len-1, 0):
|
for i in countdown(exponent.len-1, 0):
|
||||||
for bit in unpackLE(exponent[i]):
|
for bit in unpackLE(exponent[i]):
|
||||||
if bit:
|
if bit:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user