Fix lowest-hanging fruit in VM (#2382)

* replace set with bitseq for code validity test
* remove unusued code from CodeStream
* avoid unnecessary byte-by-byte copies
This commit is contained in:
Jacek Sieka 2024-06-18 02:55:35 +02:00 committed by GitHub
parent 135ef222a2
commit 8926da02b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 67 additions and 91 deletions

View File

@ -613,7 +613,7 @@ proc clearEmptyAccounts(ac: AccountsLedgerRef) =
ac.ripemdSpecial = false ac.ripemdSpecial = false
proc persist*(ac: AccountsLedgerRef, proc persist*(ac: AccountsLedgerRef,
clearEmptyAccount: bool = false) {.deprecated.} = clearEmptyAccount: bool = false) =
# make sure all savepoint already committed # make sure all savepoint already committed
doAssert(ac.savePoint.parentSavepoint.isNil) doAssert(ac.savePoint.parentSavepoint.isNil)

View File

@ -6,7 +6,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms. # at your option. This file may not be copied, modified, or distributed except according to those terms.
import import
chronicles, strformat, strutils, sequtils, parseutils, sets, macros, chronicles, strformat, strutils, sequtils, parseutils,
eth/common, eth/common,
./interpreter/op_codes ./interpreter/op_codes
@ -17,20 +17,25 @@ type
CodeStream* = ref object CodeStream* = ref object
bytes*: seq[byte] bytes*: seq[byte]
depthProcessed: int depthProcessed: int
invalidPositions: HashSet[int] invalidPositions: seq[byte] # bit seq of invalid jump positions
pc*: int pc*: int
cached: seq[(int, Op, string)]
proc `$`*(b: byte): string = proc `$`*(b: byte): string =
$(b.int) $(b.int)
template bitpos(pos: int): (int, byte) =
(pos shr 3, 1'u8 shl (pos and 0x07))
proc newCodeStream*(codeBytes: sink seq[byte]): CodeStream = proc newCodeStream*(codeBytes: sink seq[byte]): CodeStream =
new(result) new(result)
result.bytes = system.move(codeBytes) result.bytes = system.move(codeBytes)
result.pc = 0 result.pc = 0
result.invalidPositions = HashSet[int]() result.invalidPositions = newSeq[byte]((result.bytes.len + 7) div 8)
result.depthProcessed = 0 result.depthProcessed = 0
result.cached = @[]
proc invalidPosition(c: CodeStream, pos: int): bool =
let (bpos, bbit) = bitpos(pos)
(c.invalidPositions[bpos] and bbit) > 0
proc newCodeStream*(codeBytes: string): CodeStream = proc newCodeStream*(codeBytes: string): CodeStream =
newCodeStream(codeBytes.mapIt(it.byte)) newCodeStream(codeBytes.mapIt(it.byte))
@ -45,14 +50,15 @@ proc newCodeStreamFromUnescaped*(code: string): CodeStream =
codeBytes.add(value.byte) codeBytes.add(value.byte)
newCodeStream(codeBytes) newCodeStream(codeBytes)
proc read*(c: var CodeStream, size: int): seq[byte] = template read*(c: CodeStream, size: int): openArray[byte] =
# TODO: use openArray[bytes] # TODO: use openArray[bytes]
if c.pc + size - 1 < c.bytes.len: if c.pc + size - 1 < c.bytes.len:
result = c.bytes[c.pc .. c.pc + size - 1] let pos = c.pc
c.pc += size c.pc += size
c.bytes.toOpenArray(pos, pos + size - 1)
else: else:
result = @[]
c.pc = c.bytes.len c.pc = c.bytes.len
c.bytes.toOpenArray(0, -1)
proc readVmWord*(c: var CodeStream, n: int): UInt256 = proc readVmWord*(c: var CodeStream, n: int): UInt256 =
## Reads `n` bytes from the code stream and pads ## Reads `n` bytes from the code stream and pads
@ -72,11 +78,11 @@ proc next*(c: var CodeStream): Op =
result = Op(c.bytes[c.pc]) result = Op(c.bytes[c.pc])
inc c.pc inc c.pc
else: else:
result = Stop result = Op.Stop
iterator items*(c: var CodeStream): Op = iterator items*(c: var CodeStream): Op =
var nextOpcode = c.next() var nextOpcode = c.next()
while nextOpcode != Op.STOP: while nextOpcode != Op.Stop:
yield nextOpcode yield nextOpcode
nextOpcode = c.next() nextOpcode = c.next()
@ -85,76 +91,50 @@ proc `[]`*(c: CodeStream, offset: int): Op =
proc peek*(c: var CodeStream): Op = proc peek*(c: var CodeStream): Op =
if c.pc < c.bytes.len: if c.pc < c.bytes.len:
result = Op(c.bytes[c.pc]) Op(c.bytes[c.pc])
else: else:
result = Stop Op.Stop
proc updatePc*(c: var CodeStream, value: int) = proc updatePc*(c: var CodeStream, value: int) =
c.pc = min(value, len(c)) c.pc = min(value, len(c))
when false:
template seek*(cs: var CodeStream, pc: int, handler: untyped): untyped =
var anchorPc = cs.pc
cs.pc = pc
try:
var c {.inject.} = cs
handler
finally:
cs.pc = anchorPc
proc isValidOpcode*(c: CodeStream, position: int): bool = proc isValidOpcode*(c: CodeStream, position: int): bool =
if position >= len(c): if position >= len(c):
return false false
if position in c.invalidPositions: elif c.invalidPosition(position):
return false false
if position <= c.depthProcessed: elif position <= c.depthProcessed:
return true true
else: else:
var i = c.depthProcessed var i = c.depthProcessed
while i <= position: while i <= position:
var opcode = Op(c[i]) var opcode = Op(c[i])
if opcode >= Op.PUSH1 and opcode <= Op.PUSH32: if opcode >= Op.Push1 and opcode <= Op.Push32:
var leftBound = (i + 1) var leftBound = (i + 1)
var rightBound = leftBound + (opcode.int - 95) var rightBound = leftBound + (opcode.int - 95)
for z in leftBound ..< rightBound: for z in leftBound ..< rightBound:
c.invalidPositions.incl(z) let (bpos, bbit) = bitpos(z)
c.invalidPositions[bpos] = c.invalidPositions[bpos] or bbit
i = rightBound i = rightBound
else: else:
c.depthProcessed = i
i += 1 i += 1
if position in c.invalidPositions: c.depthProcessed = i - 1
return false
else:
return true
proc decompile*(original: var CodeStream): seq[(int, Op, string)] = not c.invalidPosition(position)
proc decompile*(original: CodeStream): seq[(int, Op, string)] =
# behave as https://etherscan.io/opcode-tool # behave as https://etherscan.io/opcode-tool
# TODO
if original.cached.len > 0:
return original.cached
result = @[]
var c = newCodeStream(original.bytes) var c = newCodeStream(original.bytes)
while true: while true:
var op = c.next var op = c.next
if op >= Push1 and op <= Push32: if op >= Push1 and op <= Push32:
let bytes = c.read(op.int - 95) result.add(
result.add((c.pc - 1, op, "0x" & bytes.mapIt($(it.BiggestInt.toHex(2))).join(""))) (c.pc - 1, op, "0x" & c.read(op.int - 95).mapIt($(it.BiggestInt.toHex(2))).join("")))
elif op != Op.Stop: elif op != Op.Stop:
result.add((c.pc - 1, op, "")) result.add((c.pc - 1, op, ""))
else: else:
result.add((-1, Op.Stop, "")) result.add((-1, Op.Stop, ""))
break break
original.cached = result
proc displayDecompiled*(c: CodeStream) =
var copy = c
let opcodes = copy.decompile()
for op in opcodes:
echo op[0], " ", op[1], " ", op[2]
proc hasSStore*(c: var CodeStream): bool =
let opcodes = c.decompile()
result = opcodes.anyIt(it[1] == Sstore)
proc atEnd*(c: CodeStream): bool = proc atEnd*(c: CodeStream): bool =
result = c.pc >= c.bytes.len c.pc >= c.bytes.len

View File

@ -276,7 +276,7 @@ proc callOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress, contractAddress: p.contractAddress,
codeAddress: p.codeAddress, codeAddress: p.codeAddress,
value: p.value, value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen), data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags)) flags: p.flags))
ok() ok()
@ -353,7 +353,7 @@ proc callCodeOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress, contractAddress: p.contractAddress,
codeAddress: p.codeAddress, codeAddress: p.codeAddress,
value: p.value, value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen), data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags)) flags: p.flags))
ok() ok()
@ -425,7 +425,7 @@ proc delegateCallOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress, contractAddress: p.contractAddress,
codeAddress: p.codeAddress, codeAddress: p.codeAddress,
value: p.value, value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen), data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags)) flags: p.flags))
ok() ok()
@ -498,7 +498,7 @@ proc staticCallOp(k: var VmCtx): EvmResultVoid =
contractAddress: p.contractAddress, contractAddress: p.contractAddress,
codeAddress: p.codeAddress, codeAddress: p.codeAddress,
value: p.value, value: p.value,
data: cpt.memory.read(p.memInPos, p.memInLen), data: @(cpt.memory.read(p.memInPos, p.memInLen)),
flags: p.flags)) flags: p.flags))
ok() ok()

View File

@ -156,7 +156,7 @@ proc createOp(k: var VmCtx): EvmResultVoid =
gas: createMsgGas, gas: createMsgGas,
sender: cpt.msg.contractAddress, sender: cpt.msg.contractAddress,
value: endowment, value: endowment,
data: cpt.memory.read(memPos, memLen))) data: @(cpt.memory.read(memPos, memLen))))
ok() ok()
# --------------------- # ---------------------
@ -240,7 +240,7 @@ proc create2Op(k: var VmCtx): EvmResultVoid =
gas: createMsgGas, gas: createMsgGas,
sender: cpt.msg.contractAddress, sender: cpt.msg.contractAddress,
value: endowment, value: endowment,
data: cpt.memory.read(memPos, memLen))) data: @(cpt.memory.read(memPos, memLen))))
ok() ok()
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------

View File

@ -15,6 +15,7 @@
{.push raises: [].} {.push raises: [].}
import import
stew/assign2,
../../../constants, ../../../constants,
../../evm_errors, ../../evm_errors,
../../computation, ../../computation,
@ -77,7 +78,7 @@ proc logImpl(c: Computation, opcode: Op, topicCount: int): EvmResultVoid =
let topic = ? c.stack.popTopic() let topic = ? c.stack.popTopic()
log.topics.add topic log.topics.add topic
log.data = c.memory.read(memPos, len) assign(log.data, c.memory.read(memPos, len))
log.address = c.msg.contractAddress log.address = c.msg.contractAddress
c.addLogEntry(log) c.addLogEntry(log)

View File

@ -112,7 +112,7 @@ func jumpImpl(c: Computation; jumpTarget: UInt256): EvmResultVoid =
if nextOpcode != JumpDest: if nextOpcode != JumpDest:
return err(opErr(InvalidJumpDest)) return err(opErr(InvalidJumpDest))
# TODO: next check seems redundant # Jump destination must be a valid opcode
if not c.code.isValidOpcode(jt): if not c.code.isValidOpcode(jt):
return err(opErr(InvalidJumpDest)) return err(opErr(InvalidJumpDest))

View File

@ -15,6 +15,7 @@
{.push raises: [].} {.push raises: [].}
import import
stew/assign2,
../../evm_errors, ../../evm_errors,
../../computation, ../../computation,
../../memory, ../../memory,
@ -46,7 +47,7 @@ proc returnOp(k: var VmCtx): EvmResultVoid =
k.cpt.gasCosts[Return].m_handler(k.cpt.memory.len, pos, len), k.cpt.gasCosts[Return].m_handler(k.cpt.memory.len, pos, len),
reason = "RETURN") reason = "RETURN")
k.cpt.memory.extend(pos, len) k.cpt.memory.extend(pos, len)
k.cpt.output = k.cpt.memory.read(pos, len) assign(k.cpt.output, k.cpt.memory.read(pos, len))
ok() ok()
@ -61,7 +62,7 @@ proc revertOp(k: var VmCtx): EvmResultVoid =
reason = "REVERT") reason = "REVERT")
k.cpt.memory.extend(pos, len) k.cpt.memory.extend(pos, len)
k.cpt.output = k.cpt.memory.read(pos, len) assign(k.cpt.output, k.cpt.memory.read(pos, len))
# setError(msg, false) will signal cheap revert # setError(msg, false) will signal cheap revert
k.cpt.setError(EVMC_REVERT, "REVERT opcode executed", false) k.cpt.setError(EVMC_REVERT, "REVERT opcode executed", false)
ok() ok()

View File

@ -11,6 +11,7 @@
{.push raises: [].} {.push raises: [].}
import import
stew/assign2,
./evm_errors, ./evm_errors,
./interpreter/utils/utils_numeric ./interpreter/utils/utils_numeric
@ -20,13 +21,12 @@ type
func new*(_: type EvmMemoryRef): EvmMemoryRef = func new*(_: type EvmMemoryRef): EvmMemoryRef =
new(result) new(result)
result.bytes = @[]
func len*(memory: EvmMemoryRef): int = func len*(memory: EvmMemoryRef): int =
result = memory.bytes.len memory.bytes.len
func extend*(memory: EvmMemoryRef; startPos: Natural; size: Natural) = func extend*(memory: EvmMemoryRef; startPos, size: int) =
if size == 0: if size <= 0:
return return
let newSize = ceil32(startPos + size) let newSize = ceil32(startPos + size)
if newSize <= len(memory): if newSize <= len(memory):
@ -37,12 +37,10 @@ func new*(_: type EvmMemoryRef, size: Natural): EvmMemoryRef =
result = EvmMemoryRef.new() result = EvmMemoryRef.new()
result.extend(0, size) result.extend(0, size)
func read*(memory: EvmMemoryRef, startPos: Natural, size: Natural): seq[byte] = template read*(memory: EvmMemoryRef, startPos, size: int): openArray[byte] =
result = newSeq[byte](size) memory.bytes.toOpenArray(startPos, startPos + size - 1)
if size > 0:
copyMem(result[0].addr, memory.bytes[startPos].addr, size)
template read32Bytes*(memory: EvmMemoryRef, startPos): auto = template read32Bytes*(memory: EvmMemoryRef, startPos: int): openArray[byte] =
memory.bytes.toOpenArray(startPos, startPos + 31) memory.bytes.toOpenArray(startPos, startPos + 31)
when defined(evmc_enabled): when defined(evmc_enabled):
@ -56,8 +54,8 @@ func write*(memory: EvmMemoryRef, startPos: Natural, value: openArray[byte]): Ev
return return
if startPos + size > memory.len: if startPos + size > memory.len:
return err(memErr(MemoryFull)) return err(memErr(MemoryFull))
for z, b in value:
memory.bytes[z + startPos] = b assign(memory.bytes.toOpenArray(startPos, int(startPos + size) - 1), value)
ok() ok()
func write*(memory: EvmMemoryRef, startPos: Natural, value: byte): EvmResultVoid = func write*(memory: EvmMemoryRef, startPos: Natural, value: byte): EvmResultVoid =
@ -71,19 +69,16 @@ func copy*(memory: EvmMemoryRef, dst, src, len: Natural) =
memory.extend(max(dst, src), len) memory.extend(max(dst, src), len)
if dst == src: if dst == src:
return return
elif dst < src: assign(
for i in 0..<len: memory.bytes.toOpenArray(dst, dst + len - 1),
memory.bytes[dst+i] = memory.bytes[src+i] memory.bytes.toOpenArray(src, src + len - 1))
else: # src > dst
for i in countdown(len-1, 0):
memory.bytes[dst+i] = memory.bytes[src+i]
func writePadded*(memory: EvmMemoryRef, data: openArray[byte], func writePadded*(memory: EvmMemoryRef, data: openArray[byte],
memPos, dataPos, len: Natural) = memPos, dataPos, len: Natural) =
memory.extend(memPos, len) memory.extend(memPos, len)
let let
dataEndPos = dataPos.int64 + len dataEndPos = dataPos + len
dataStart = min(dataPos, data.len) dataStart = min(dataPos, data.len)
dataEnd = min(data.len, dataEndPos) dataEnd = min(data.len, dataEndPos)
dataLen = dataEnd - dataStart dataLen = dataEnd - dataStart
@ -95,10 +90,10 @@ func writePadded*(memory: EvmMemoryRef, data: openArray[byte],
di = dataStart di = dataStart
mi = memPos mi = memPos
while di < dataEnd: assign(
memory.bytes[mi] = data[di] memory.bytes.toOpenArray(mi, mi + dataLen - 1),
inc di data.toOpenArray(di, di + dataLen - 1))
inc mi mi += dataLen
# although memory.extend already pad new block of memory # although memory.extend already pad new block of memory
# with zeros, it can be rewrite by some opcode # with zeros, it can be rewrite by some opcode

View File

@ -45,9 +45,9 @@ proc memoryMain*() =
test "read returns correct bytes": test "read returns correct bytes":
var mem = memory32() var mem = memory32()
check mem.write(startPos = 5, value = @[1.byte, 0.byte, 1.byte, 0.byte]).isOk check mem.write(startPos = 5, value = @[1.byte, 0.byte, 1.byte, 0.byte]).isOk
check(mem.read(startPos = 5, size = 4) == @[1.byte, 0.byte, 1.byte, 0.byte]) check(@(mem.read(startPos = 5, size = 4)) == @[1.byte, 0.byte, 1.byte, 0.byte])
check(mem.read(startPos = 6, size = 4) == @[0.byte, 1.byte, 0.byte, 0.byte]) check(@(mem.read(startPos = 6, size = 4)) == @[0.byte, 1.byte, 0.byte, 0.byte])
check(mem.read(startPos = 1, size = 3) == @[0.byte, 0.byte, 0.byte]) check(@(mem.read(startPos = 1, size = 3)) == @[0.byte, 0.byte, 0.byte])
when isMainModule: when isMainModule:
memoryMain() memoryMain()

View File

@ -13,7 +13,7 @@ import
unittest2, unittest2,
eth/rlp, eth/rlp,
./test_helpers, ./test_helpers,
../nimbus/[errors, transaction, evm/types], ../nimbus/[errors, transaction],
../nimbus/utils/utils ../nimbus/utils/utils
const const

View File

@ -19,7 +19,6 @@ import
../nimbus/core/casper, ../nimbus/core/casper,
../nimbus/common/common, ../nimbus/common/common,
../nimbus/utils/utils, ../nimbus/utils/utils,
../nimbus/evm/types,
./test_txpool/helpers, ./test_txpool/helpers,
./macro_assembler ./macro_assembler