support custom errors in contract calls

Currently only errors without arguments
This commit is contained in:
Mark Spanbroek 2024-03-19 15:27:51 +01:00 committed by markspanbroek
parent 6b57e56a39
commit 74f15fca9c
6 changed files with 115 additions and 0 deletions

View File

@ -8,11 +8,13 @@ import ./basics
import ./provider
import ./signer
import ./events
import ./errors
import ./fields
export basics
export provider
export events
export errors
logScope:
topics = "ethers contract"
@ -138,6 +140,17 @@ func getParameterTuple(procedure: NimNode): NimNode =
tupl.add name
return tupl
func getErrorTypes(procedure: NimNode): NimNode =
let pragmas = procedure[4]
var tupl = newNimNode(nnkTupleConstr)
for pragma in pragmas:
if pragma.kind == nnkExprColonExpr:
if pragma[0].eqIdent "errors":
pragma[1].expectKind(nnkBracket)
for error in pragma[1]:
tupl.add error
tupl
func isGetter(procedure: NimNode): bool =
let pragmas = procedure[4]
for pragma in pragmas:
@ -215,6 +228,13 @@ func addContractCall(procedure: var NimNode) =
else:
send()
func addErrorHandling(procedure: var NimNode) =
let body = procedure[6]
let errors = getErrorTypes(procedure)
procedure[6] = quote do:
convertCustomErrors[`errors`]:
`body`
func addFuture(procedure: var NimNode) =
let returntype = procedure[3][0]
if returntype.kind != nnkEmpty:
@ -236,6 +256,7 @@ macro contract*(procedure: untyped{nkProcDef|nkMethodDef}): untyped =
var contractcall = copyNimTree(procedure)
contractcall.addContractCall()
contractcall.addErrorHandling()
contractcall.addFuture()
contractcall.addAsyncPragma()
contractcall

View File

@ -15,3 +15,14 @@ func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) =
if selector.toArray[0..<4] != data[0..<4]:
return failure "unable to decode " & name & ": signature doesn't match"
success (ref E)()
template convertCustomErrors*[ErrorTypes: tuple](body: untyped): untyped =
try:
body
except ProviderError as error:
block:
if data =? error.data:
for e in ErrorTypes.default.fields:
if error =? typeof(e).decode(data):
raise error
raise error

View File

@ -8,5 +8,6 @@ import ./testTesting
import ./testErc20
import ./testGasEstimation
import ./testErrorDecoding
import ./testCustomErrors
{.warning[UnusedImport]:off.}

View File

@ -0,0 +1,32 @@
import std/json
import pkg/asynctest
import pkg/ethers
import ./hardhat
suite "Contract custom errors":
type
TestCustomErrors = ref object of Contract
SimpleError = object of SolidityError
var contract: TestCustomErrors
var provider: JsonRpcProvider
var snapshot: JsonNode
setup:
provider = JsonRpcProvider.new()
snapshot = await provider.send("evm_snapshot")
let deployment = readDeployment()
let address = !deployment.address(TestCustomErrors)
contract = TestCustomErrors.new(address, provider)
teardown:
discard await provider.send("evm_revert", @[snapshot])
await provider.close()
test "handles simple errors":
proc revertsSimpleError(contract: TestCustomErrors)
{.contract, pure, errors:[SimpleError].}
expect SimpleError:
await contract.revertsSimpleError()

View File

@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;
contract TestCustomErrors {
error SimpleError();
error ErrorWithArguments(uint256 one, bool two);
error ErrorWithStaticStruct(StaticStruct one, StaticStruct two);
error ErrorWithDynamicStruct(DynamicStruct one, DynamicStruct two);
error ErrorWithDynamicAndStaticStruct(DynamicStruct one, StaticStruct two);
struct StaticStruct {
uint256 a;
uint256 b;
}
struct DynamicStruct {
string a;
uint256 b;
}
function revertsSimpleError() public pure {
revert SimpleError();
}
function revertsErrorWithArguments() public pure {
revert ErrorWithArguments(1, true);
}
function revertsErrorWithStaticStruct() public pure {
revert ErrorWithStaticStruct(StaticStruct(1, 2), StaticStruct(3, 4));
}
function revertsErrorWithDynamicStruct() public pure {
revert ErrorWithDynamicStruct(DynamicStruct("1", 2), DynamicStruct("3", 4));
}
function revertsErrorWithDynamicAndStaticStruct() public pure {
revert ErrorWithDynamicAndStaticStruct(
DynamicStruct("1", 2),
StaticStruct(3, 4)
);
}
}

View File

@ -0,0 +1,6 @@
module.exports = async ({ deployments, getNamedAccounts }) => {
const { deployer } = await getNamedAccounts();
await deployments.deploy("TestCustomErrors", { from: deployer });
};
module.exports.tags = ["TestCustomErrors"];