Fix lowest-hanging fruit in VM (#2382)

* replace set with bitseq for code validity test
* remove unusued code from CodeStream
* avoid unnecessary byte-by-byte copies
This commit is contained in:
Jacek Sieka 2024-06-18 02:55:35 +02:00 committed by GitHub
parent 135ef222a2
commit 8926da02b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 67 additions and 91 deletions

View File

@ -613,7 +613,7 @@ proc clearEmptyAccounts(ac: AccountsLedgerRef) =
ac.ripemdSpecial = false
proc persist*(ac: AccountsLedgerRef,
clearEmptyAccount: bool = false) {.deprecated.} =
clearEmptyAccount: bool = false) =
# make sure all savepoint already committed
doAssert(ac.savePoint.parentSavepoint.isNil)

View File

@ -6,7 +6,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
chronicles, strformat, strutils, sequtils, parseutils, sets, macros,
chronicles, strformat, strutils, sequtils, parseutils,
eth/common,
./interpreter/op_codes
@ -17,20 +17,25 @@ type
CodeStream* = ref object
bytes*: seq[byte]
depthProcessed: int
invalidPositions: HashSet[int]
invalidPositions: seq[byte] # bit seq of invalid jump positions
pc*: int
cached: seq[(int, Op, string)]
proc `$`*(b: byte): string =
$(b.int)
template bitpos(pos: int): (int, byte) =
(pos shr 3, 1'u8 shl (pos and 0x07))
proc newCodeStream*(codeBytes: sink seq[byte]): CodeStream =
new(result)
result.bytes = system.move(codeBytes)
result.pc = 0
result.invalidPositions = HashSet[int]()
result.invalidPositions = newSeq[byte]((result.bytes.len + 7) div 8)
result.depthProcessed = 0
result.cached = @[]
proc invalidPosition(c: CodeStream, pos: int): bool =
let (bpos, bbit) = bitpos(pos)
(c.invalidPositions[bpos] and bbit) > 0
proc newCodeStream*(codeBytes: string): CodeStream =
newCodeStream(codeBytes.mapIt(it.byte))
@ -45,14 +50,15 @@ proc newCodeStreamFromUnescaped*(code: string): CodeStream =
codeBytes.add(value.byte)
newCodeStream(codeBytes)
proc read*(c: var CodeStream, size: int): seq[byte] =
template read*(c: CodeStream, size: int): openArray[byte] =
# TODO: use openArray[bytes]
if c.pc + size - 1 < c.bytes.len:
result = c.bytes[c.pc .. c.pc + size - 1]
let pos = c.pc
c.pc += size
c.bytes.toOpenArray(pos, pos + size - 1)
else:
result = @[]
c.pc = c.bytes.len
c.bytes.toOpenArray(0, -1)
proc readVmWord*(c: var CodeStream, n: int): UInt256 =
## Reads `n` bytes from the code stream and pads
@ -72,11 +78,11 @@ proc next*(c: var CodeStream): Op =
result = Op(c.bytes[c.pc])
inc c.pc
else:
result = Stop
result = Op.Stop
iterator items*(c: var CodeStream): Op =
var nextOpcode = c.next()
while nextOpcode != Op.STOP:
while nextOpcode != Op.Stop:
yield nextOpcode
nextOpcode = c.next()
@ -85,76 +91,50 @@ proc `[]`*(c: CodeStream, offset: int): Op =
proc peek*(c: var CodeStream): Op =
if c.pc < c.bytes.len:
result = Op(c.bytes[c.pc])
Op(c.bytes[c.pc])
else:
result = Stop
Op.Stop
proc updatePc*(c: var CodeStream, value: int) =
c.pc = min(value, len(c))
when false:
template seek*(cs: var CodeStream, pc: int, handler: untyped): untyped =
var anchorPc = cs.pc
cs.pc = pc
try:
var c {.inject.} = cs
handler
finally:
cs.pc = anchorPc
proc isValidOpcode*(c: CodeStream, position: int): bool =
if position >= len(c):
return false
if position in c.invalidPositions:
return false
if position <= c.depthProcessed:
return true
false
elif c.invalidPosition(position):
false
elif position <= c.depthProcessed:
true
else:
var i = c.depthProcessed
while i <= position:
var opcode = Op(c[i])
if opcode >= Op.PUSH1 and opcode <= Op.PUSH32:
if opcode >= Op.Push1 and opcode <= Op.Push32:
var leftBound = (i + 1)
var rightBound = leftBound + (opcode.int - 95)
for z in leftBound ..< rightBound:
c.invalidPositions.incl(z)
let (bpos, bbit) = bitpos(z)
c.invalidPositions[bpos] = c.invalidPositions[bpos] or bbit
i = rightBound
else:
c.depthProcessed = i
i += 1
if position in c.invalidPositions:
return false
else:
return true
c.depthProcessed = i - 1
proc decompile*(original: var CodeStream): seq[(int, Op, string)] =
not c.invalidPosition(position)
proc decompile*(original: CodeStream): seq[(int, Op, string)] =
# behave as https://etherscan.io/opcode-tool
# TODO
if original.cached.len > 0:
return original.cached
result = @[]
var c = newCodeStream(original.bytes)
while true:
var op = c.next
if op >= Push1 and op <= Push32:
let bytes = c.read(op.int - 95)
result.add((c.pc - 1, op, "0x" & bytes.mapIt($(it.BiggestInt.toHex(2))).join("")))
result.add(
(c.pc - 1, op, "0x" & c.read(op.int - 95).mapIt($(it.BiggestInt.toHex(2))).join("")))
elif op != Op.Stop:
result.add((c.pc - 1, op, ""))
else:
result.add((-1, Op.Stop, ""))
break
original.cached = result
proc displayDecompiled*(c: CodeStream) =
var copy = c
let opcodes = copy.decompile()
for op in opcodes:
echo op[0], " ", op[1], " ", op[2]
proc hasSStore*(c: var CodeStream): bool =
let opcodes = c.decompile()
result = opcodes.anyIt(it[1] == Sstore)
proc atEnd*(c: CodeStream): bool =
result = c.pc >= c.bytes.len
c.pc >= c.bytes.len

View File

@ -276,7 +276,7 @@ proc callOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress,
codeAddress: p.codeAddress,
value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen),
data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags))
ok()
@ -353,7 +353,7 @@ proc callCodeOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress,
codeAddress: p.codeAddress,
value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen),
data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags))
ok()
@ -425,7 +425,7 @@ proc delegateCallOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress,
codeAddress: p.codeAddress,
value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen),
data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags))
ok()
@ -498,7 +498,7 @@ proc staticCallOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress,
codeAddress: p.codeAddress,
value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen),
data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags))
ok()

View File

@ -156,7 +156,7 @@ proc createOp(k: var VmCtx): EvmResultVoid =
gas: createMsgGas,
sender: cpt.msg.contractAddress,
value: endowment,
data: cpt.memory.read(memPos, memLen)))
data: @(cpt.memory.read(memPos, memLen))))
ok()
# ---------------------
@ -240,7 +240,7 @@ proc create2Op(k: var VmCtx): EvmResultVoid =
gas: createMsgGas,
sender: cpt.msg.contractAddress,
value: endowment,
data: cpt.memory.read(memPos, memLen)))
data: @(cpt.memory.read(memPos, memLen))))
ok()
# ------------------------------------------------------------------------------

View File

@ -15,6 +15,7 @@
{.push raises: [].}
import
stew/assign2,
../../../constants,
../../evm_errors,
../../computation,
@ -77,7 +78,7 @@ proc logImpl(c: Computation, opcode: Op, topicCount: int): EvmResultVoid =
let topic = ? c.stack.popTopic()
log.topics.add topic
log.data = c.memory.read(memPos, len)
assign(log.data, c.memory.read(memPos, len))
log.address = c.msg.contractAddress
c.addLogEntry(log)

View File

@ -112,7 +112,7 @@ func jumpImpl(c: Computation; jumpTarget: UInt256): EvmResultVoid =
if nextOpcode != JumpDest:
return err(opErr(InvalidJumpDest))
# TODO: next check seems redundant
# Jump destination must be a valid opcode
if not c.code.isValidOpcode(jt):
return err(opErr(InvalidJumpDest))

View File

@ -15,6 +15,7 @@
{.push raises: [].}
import
stew/assign2,
../../evm_errors,
../../computation,
../../memory,
@ -46,7 +47,7 @@ proc returnOp(k: var VmCtx): EvmResultVoid =
k.cpt.gasCosts[Return].m_handler(k.cpt.memory.len, pos, len),
reason = "RETURN")
k.cpt.memory.extend(pos, len)
k.cpt.output = k.cpt.memory.read(pos, len)
assign(k.cpt.output, k.cpt.memory.read(pos, len))
ok()
@ -61,7 +62,7 @@ proc revertOp(k: var VmCtx): EvmResultVoid =
reason = "REVERT")
k.cpt.memory.extend(pos, len)
k.cpt.output = k.cpt.memory.read(pos, len)
assign(k.cpt.output, k.cpt.memory.read(pos, len))
# setError(msg, false) will signal cheap revert
k.cpt.setError(EVMC_REVERT, "REVERT opcode executed", false)
ok()

View File

@ -11,6 +11,7 @@
{.push raises: [].}
import
stew/assign2,
./evm_errors,
./interpreter/utils/utils_numeric
@ -20,13 +21,12 @@ type
func new*(_: type EvmMemoryRef): EvmMemoryRef =
new(result)
result.bytes = @[]
func len*(memory: EvmMemoryRef): int =
result = memory.bytes.len
memory.bytes.len
func extend*(memory: EvmMemoryRef; startPos: Natural; size: Natural) =
if size == 0:
func extend*(memory: EvmMemoryRef; startPos, size: int) =
if size <= 0:
return
let newSize = ceil32(startPos + size)
if newSize <= len(memory):
@ -37,12 +37,10 @@ func new*(_: type EvmMemoryRef, size: Natural): EvmMemoryRef =
result = EvmMemoryRef.new()
result.extend(0, size)
func read*(memory: EvmMemoryRef, startPos: Natural, size: Natural): seq[byte] =
result = newSeq[byte](size)
if size > 0:
copyMem(result[0].addr, memory.bytes[startPos].addr, size)
template read*(memory: EvmMemoryRef, startPos, size: int): openArray[byte] =
memory.bytes.toOpenArray(startPos, startPos + size - 1)
template read32Bytes*(memory: EvmMemoryRef, startPos): auto =
template read32Bytes*(memory: EvmMemoryRef, startPos: int): openArray[byte] =
memory.bytes.toOpenArray(startPos, startPos + 31)
when defined(evmc_enabled):
@ -56,8 +54,8 @@ func write*(memory: EvmMemoryRef, startPos: Natural, value: openArray[byte]): Ev
return
if startPos + size > memory.len:
return err(memErr(MemoryFull))
for z, b in value:
memory.bytes[z + startPos] = b
assign(memory.bytes.toOpenArray(startPos, int(startPos + size) - 1), value)
ok()
func write*(memory: EvmMemoryRef, startPos: Natural, value: byte): EvmResultVoid =
@ -71,19 +69,16 @@ func copy*(memory: EvmMemoryRef, dst, src, len: Natural) =
memory.extend(max(dst, src), len)
if dst == src:
return
elif dst < src:
for i in 0..<len:
memory.bytes[dst+i] = memory.bytes[src+i]
else: # src > dst
for i in countdown(len-1, 0):
memory.bytes[dst+i] = memory.bytes[src+i]
assign(
memory.bytes.toOpenArray(dst, dst + len - 1),
memory.bytes.toOpenArray(src, src + len - 1))
func writePadded*(memory: EvmMemoryRef, data: openArray[byte],
memPos, dataPos, len: Natural) =
memory.extend(memPos, len)
let
dataEndPos = dataPos.int64 + len
dataEndPos = dataPos + len
dataStart = min(dataPos, data.len)
dataEnd = min(data.len, dataEndPos)
dataLen = dataEnd - dataStart
@ -95,10 +90,10 @@ func writePadded*(memory: EvmMemoryRef, data: openArray[byte],
di = dataStart
mi = memPos
while di < dataEnd:
memory.bytes[mi] = data[di]
inc di
inc mi
assign(
memory.bytes.toOpenArray(mi, mi + dataLen - 1),
data.toOpenArray(di, di + dataLen - 1))
mi += dataLen
# although memory.extend already pad new block of memory
# with zeros, it can be rewrite by some opcode

View File

@ -45,9 +45,9 @@ proc memoryMain*() =
test "read returns correct bytes":
var mem = memory32()
check mem.write(startPos = 5, value = @[1.byte, 0.byte, 1.byte, 0.byte]).isOk
check(mem.read(startPos = 5, size = 4) == @[1.byte, 0.byte, 1.byte, 0.byte])
check(mem.read(startPos = 6, size = 4) == @[0.byte, 1.byte, 0.byte, 0.byte])
check(mem.read(startPos = 1, size = 3) == @[0.byte, 0.byte, 0.byte])
check(@(mem.read(startPos = 5, size = 4)) == @[1.byte, 0.byte, 1.byte, 0.byte])
check(@(mem.read(startPos = 6, size = 4)) == @[0.byte, 1.byte, 0.byte, 0.byte])
check(@(mem.read(startPos = 1, size = 3)) == @[0.byte, 0.byte, 0.byte])
when isMainModule:
memoryMain()

View File

@ -13,7 +13,7 @@ import
unittest2,
eth/rlp,
./test_helpers,
../nimbus/[errors, transaction, evm/types],
../nimbus/[errors, transaction],
../nimbus/utils/utils
const

View File

@ -19,7 +19,6 @@ import
../nimbus/core/casper,
../nimbus/common/common,
../nimbus/utils/utils,
../nimbus/evm/types,
./test_txpool/helpers,
./macro_assembler