diff --git a/nimbus/vm/precompiles.nim b/nimbus/vm/precompiles.nim index 36f71a174..08e82f499 100644 --- a/nimbus/vm/precompiles.nim +++ b/nimbus/vm/precompiles.nim @@ -178,16 +178,22 @@ proc modExpInternal(computation: Computation, baseLen, expLen, modLen: int, T: t computation.output = newSeq[byte](modLen) computation.output[^output.len..^1] = output[0..^1] -proc modExpFee(c: Computation, baseLen, expLen, modLen: Uint256): GasInt = +proc modExpFee(c: Computation, baseLen, expLen, modLen: Uint256, fork: Fork): GasInt = template data: untyped {.dirty.} = c.msg.data - func gasModExp(x: Uint256): Uint256 = + func mulComplexity(x: Uint256): Uint256 = ## Estimates the difficulty of Karatsuba multiplication if x <= 64.u256: x * x elif x <= 1024.u256: x * x div 4.u256 + 96.u256 * x - 3072.u256 else: x * x div 16.u256 + 480.u256 * x - 199680.u256 + func mulComplexityEIP2565(x: Uint256): Uint256 = + # gas = ceil(x div 8) ^ 2 + result = x + 7 + result = result div 8 + result = result * result + let adjExpLen = block: let baseL = baseLen.safeInt @@ -206,23 +212,29 @@ proc modExpFee(c: Computation, baseLen, expLen, modLen: Uint256): GasInt = else: 8.u256 * (expLen - 32.u256) - let gasFee = ( - max(modLen, baseLen).gasModExp * - max(adjExpLen, 1.u256) - ) div GasQuadDivisor + template gasCalc(comp, divisor: untyped): untyped = + ( + max(modLen, baseLen).comp * + max(adjExpLen, 1.u256) + ) div divisor + + let gasFee = if fork >= FkBerlin: gasCalc(mulComplexityEIP2565, 3) + else: gasCalc(mulComplexity, GasQuadDivisor) if gasFee > high(GasInt).u256: raise newException(OutOfGas, "modExp gas overflow") result = gasFee.truncate(GasInt) + if fork >= FkBerlin and result < 200.GasInt: + result = 200.GasInt -proc modExp*(computation: Computation) = +proc modExp*(c: Computation, fork: Fork = FkByzantium) = ## Modular exponentiation precompiled contract ## Yellow Paper Appendix E ## EIP-198 - https://github.com/ethereum/EIPs/blob/master/EIPS/eip-198.md # Parsing the data template data: untyped {.dirty.} = - computation.msg.data + c.msg.data let # lengths Base, Exponent, Modulus baseL = data.rangeToPadded[:Uint256](0, 31) @@ -232,27 +244,27 @@ proc modExp*(computation: Computation) = expLen = expL.safeInt modLen = modL.safeInt - let gasFee = modExpFee(computation, baseL, expL, modL) - computation.gasMeter.consumeGas(gasFee, reason="ModExp Precompile") + let gasFee = modExpFee(c, baseL, expL, modL, fork) + c.gasMeter.consumeGas(gasFee, reason="ModExp Precompile") if baseLen == 0 and modLen == 0: # This is a special case where expLength can be very big. - computation.output = @[] + c.output = @[] return let maxBytes = max(baseLen, max(expLen, modLen)) if maxBytes <= 32: - computation.modExpInternal(baseLen, expLen, modLen, UInt256) + c.modExpInternal(baseLen, expLen, modLen, UInt256) elif maxBytes <= 64: - computation.modExpInternal(baseLen, expLen, modLen, StUint[512]) + c.modExpInternal(baseLen, expLen, modLen, StUint[512]) elif maxBytes <= 128: - computation.modExpInternal(baseLen, expLen, modLen, StUint[1024]) + c.modExpInternal(baseLen, expLen, modLen, StUint[1024]) elif maxBytes <= 256: - computation.modExpInternal(baseLen, expLen, modLen, StUint[2048]) + c.modExpInternal(baseLen, expLen, modLen, StUint[2048]) elif maxBytes <= 512: - computation.modExpInternal(baseLen, expLen, modLen, StUint[4096]) + c.modExpInternal(baseLen, expLen, modLen, StUint[4096]) elif maxBytes <= 1024: - computation.modExpInternal(baseLen, expLen, modLen, StUint[8192]) + c.modExpInternal(baseLen, expLen, modLen, StUint[8192]) else: raise newException(EVMError, "The Nimbus VM doesn't support modular exponentiation with numbers larger than uint8192") @@ -367,7 +379,7 @@ proc execPrecompiles*(computation: Computation, fork: Fork): bool {.inline.} = of paSha256: sha256(computation) of paRipeMd160: ripeMd160(computation) of paIdentity: identity(computation) - of paModExp: modExp(computation) + of paModExp: modExp(computation, fork) of paEcAdd: bn256ecAdd(computation, fork) of paEcMul: bn256ecMul(computation, fork) of paPairing: bn256ecPairing(computation, fork) diff --git a/tests/test_precompiles.nim b/tests/test_precompiles.nim index 197f39961..7ff707f46 100644 --- a/tests/test_precompiles.nim +++ b/tests/test_precompiles.nim @@ -9,11 +9,20 @@ import unittest2, ../nimbus/vm/precompiles, json, stew/byteutils, test_helpers, os, tables, strformat, strutils, eth/trie/db, eth/common, ../nimbus/db/db_chain, ../nimbus/[vm_types, vm_state], ../nimbus/vm/[computation, message], macros, - ../nimbus/vm/blake2b_f + ../nimbus/vm/blake2b_f, ../nimbus/vm/interpreter/vm_forks proc initAddress(i: byte): EthAddress = result[19] = i -template doTest(fixture: JsonNode, address: byte, action: untyped): untyped = +const + eip198Fees = [13056, 13056, 204, 204, 3276, 665, 665, 10649, 1894, + 1894, 30310, 5580, 5580, 89292, 17868, 17868, 285900] + + eip2565Fees = [1360, 1360, 200, 200, 341, 200, 200, 1365, 341, + 341, 5461, 1365, 1365, 21845, 5461, 5461, 87381] + + +template doTest(fixture: JsonNode, address: byte, action: untyped, fees: openArray[int] = [], fork: untyped = 0): untyped = + var i = 0 for test in fixture: let blockNum = 1.u256 # TODO: Check other forks @@ -47,11 +56,24 @@ template doTest(fixture: JsonNode, address: byte, action: untyped): untyped = ) computation = newComputation(vmState, message) # echo "Running ", action.astToStr, " - ", test["name"] - `action`(computation) + + when fees.len > 0: + let initialGas = computation.gasMeter.gasRemaining + + when fork is Fork: + `action`(computation, fork) + else: + `action`(computation) let c = computation.output == expected if not c: echo "Output : " & computation.output.toHex & "\nExpected: " & expected.toHex check c + when fees.len > 0: + let gasFee = initialGas - computation.gasMeter.gasRemaining + check gasFee == fees[i] + + inc i + proc testFixture(fixtures: JsonNode, testStatusIMPL: var TestStatus) = for label, child in fixtures: case toLowerAscii(label) @@ -59,7 +81,9 @@ proc testFixture(fixtures: JsonNode, testStatusIMPL: var TestStatus) = of "sha256": child.doTest(paSha256.ord, sha256) of "ripemd": child.doTest(paRipeMd160.ord, ripemd160) of "identity": child.doTest(paIdentity.ord, identity) - of "modexp": child.doTest(paModExp.ord, modexp) + of "modexp": + child.doTest(paModExp.ord, modexp, eip198Fees) + child.doTest(paModExp.ord, modexp, eip2565Fees, FkBerlin) of "bn256add": child.doTest(paEcAdd.ord, bn256ECAdd) of "bn256mul": child.doTest(paEcMul.ord, bn256ECMul) of "ecpairing": child.doTest(paPairing.ord, bn256ecPairing)