diff --git a/ethers/contract.nim b/ethers/contract.nim index d73cb0f..ae2b4fe 100644 --- a/ethers/contract.nim +++ b/ethers/contract.nim @@ -8,11 +8,13 @@ import ./basics import ./provider import ./signer import ./events +import ./errors import ./fields export basics export provider export events +export errors logScope: topics = "ethers contract" @@ -138,6 +140,17 @@ func getParameterTuple(procedure: NimNode): NimNode = tupl.add name return tupl +func getErrorTypes(procedure: NimNode): NimNode = + let pragmas = procedure[4] + var tupl = newNimNode(nnkTupleConstr) + for pragma in pragmas: + if pragma.kind == nnkExprColonExpr: + if pragma[0].eqIdent "errors": + pragma[1].expectKind(nnkBracket) + for error in pragma[1]: + tupl.add error + tupl + func isGetter(procedure: NimNode): bool = let pragmas = procedure[4] for pragma in pragmas: @@ -215,6 +228,13 @@ func addContractCall(procedure: var NimNode) = else: send() +func addErrorHandling(procedure: var NimNode) = + let body = procedure[6] + let errors = getErrorTypes(procedure) + procedure[6] = quote do: + convertCustomErrors[`errors`]: + `body` + func addFuture(procedure: var NimNode) = let returntype = procedure[3][0] if returntype.kind != nnkEmpty: @@ -236,6 +256,7 @@ macro contract*(procedure: untyped{nkProcDef|nkMethodDef}): untyped = var contractcall = copyNimTree(procedure) contractcall.addContractCall() + contractcall.addErrorHandling() contractcall.addFuture() contractcall.addAsyncPragma() contractcall diff --git a/ethers/errors.nim b/ethers/errors.nim index 508e485..f302bbd 100644 --- a/ethers/errors.nim +++ b/ethers/errors.nim @@ -15,3 +15,14 @@ func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) = if selector.toArray[0..<4] != data[0..<4]: return failure "unable to decode " & name & ": signature doesn't match" success (ref E)() + +template convertCustomErrors*[ErrorTypes: tuple](body: untyped): untyped = + try: + body + except ProviderError as error: + block: + if data =? error.data: + for e in ErrorTypes.default.fields: + if error =? typeof(e).decode(data): + raise error + raise error diff --git a/testmodule/test.nim b/testmodule/test.nim index 854579b..fcafd80 100644 --- a/testmodule/test.nim +++ b/testmodule/test.nim @@ -8,5 +8,6 @@ import ./testTesting import ./testErc20 import ./testGasEstimation import ./testErrorDecoding +import ./testCustomErrors {.warning[UnusedImport]:off.} diff --git a/testmodule/testCustomErrors.nim b/testmodule/testCustomErrors.nim new file mode 100644 index 0000000..e01d90d --- /dev/null +++ b/testmodule/testCustomErrors.nim @@ -0,0 +1,32 @@ +import std/json +import pkg/asynctest +import pkg/ethers +import ./hardhat + +suite "Contract custom errors": + + type + TestCustomErrors = ref object of Contract + SimpleError = object of SolidityError + + var contract: TestCustomErrors + var provider: JsonRpcProvider + var snapshot: JsonNode + + setup: + provider = JsonRpcProvider.new() + snapshot = await provider.send("evm_snapshot") + let deployment = readDeployment() + let address = !deployment.address(TestCustomErrors) + contract = TestCustomErrors.new(address, provider) + + teardown: + discard await provider.send("evm_revert", @[snapshot]) + await provider.close() + + test "handles simple errors": + proc revertsSimpleError(contract: TestCustomErrors) + {.contract, pure, errors:[SimpleError].} + + expect SimpleError: + await contract.revertsSimpleError() diff --git a/testnode/contracts/TestCustomErrors.sol b/testnode/contracts/TestCustomErrors.sol new file mode 100644 index 0000000..b3ca90f --- /dev/null +++ b/testnode/contracts/TestCustomErrors.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +contract TestCustomErrors { + + error SimpleError(); + error ErrorWithArguments(uint256 one, bool two); + error ErrorWithStaticStruct(StaticStruct one, StaticStruct two); + error ErrorWithDynamicStruct(DynamicStruct one, DynamicStruct two); + error ErrorWithDynamicAndStaticStruct(DynamicStruct one, StaticStruct two); + + struct StaticStruct { + uint256 a; + uint256 b; + } + + struct DynamicStruct { + string a; + uint256 b; + } + + function revertsSimpleError() public pure { + revert SimpleError(); + } + + function revertsErrorWithArguments() public pure { + revert ErrorWithArguments(1, true); + } + + function revertsErrorWithStaticStruct() public pure { + revert ErrorWithStaticStruct(StaticStruct(1, 2), StaticStruct(3, 4)); + } + + function revertsErrorWithDynamicStruct() public pure { + revert ErrorWithDynamicStruct(DynamicStruct("1", 2), DynamicStruct("3", 4)); + } + + function revertsErrorWithDynamicAndStaticStruct() public pure { + revert ErrorWithDynamicAndStaticStruct( + DynamicStruct("1", 2), + StaticStruct(3, 4) + ); + } +} diff --git a/testnode/deploy/testcustomerrors.js b/testnode/deploy/testcustomerrors.js new file mode 100644 index 0000000..3757b54 --- /dev/null +++ b/testnode/deploy/testcustomerrors.js @@ -0,0 +1,6 @@ +module.exports = async ({ deployments, getNamedAccounts }) => { + const { deployer } = await getNamedAccounts(); + await deployments.deploy("TestCustomErrors", { from: deployer }); +}; + +module.exports.tags = ["TestCustomErrors"];