diff --git a/ethers/contract.nim b/ethers/contract.nim index 66a22cf..dddd61b 100644 --- a/ethers/contract.nim +++ b/ethers/contract.nim @@ -9,6 +9,7 @@ import ./provider import ./signer import ./events import ./errors +import ./errors/conversion import ./fields export basics @@ -122,12 +123,14 @@ proc call(contract: Contract, proc send(contract: Contract, function: string, parameters: tuple, - overrides = TransactionOverrides()): + overrides = TransactionOverrides(), + convertCustomErrors: ConvertCustomErrors = nil): Future[?TransactionResponse] {.async.} = if signer =? contract.signer: let transaction = createTransaction(contract, function, parameters, overrides) let populated = await signer.populateTransaction(transaction) - let txResp = await signer.sendTransaction(populated) + var txResp = await signer.sendTransaction(populated) + txResp.convertCustomErrors = convertCustomErrors return txResp.some else: await call(contract, function, parameters, overrides) @@ -150,7 +153,10 @@ func getErrorTypes(procedure: NimNode): NimNode = pragma[1].expectKind(nnkBracket) for error in pragma[1]: tupl.add error - tupl + if tupl.len == 0: + quote do: tuple[] + else: + tupl func isGetter(procedure: NimNode): bool = let pragmas = procedure[4] @@ -192,6 +198,7 @@ func addContractCall(procedure: var NimNode) = let isGetter = procedure.isGetter procedure.addOverrides() + let errors = getErrorTypes(procedure) func call: NimNode = if returnType.kind == nnkEmpty: @@ -221,7 +228,8 @@ func addContractCall(procedure: var NimNode) = "unexpected return type, " & "missing {.view.}, {.pure.} or {.getter.} ?" .} - return await send(`contract`, `function`, `parameters`, overrides) + let convert = customErrorConversion(`errors`) + return await send(`contract`, `function`, `parameters`, overrides, convert) procedure[6] = if procedure.isConstant: @@ -233,8 +241,14 @@ func addErrorHandling(procedure: var NimNode) = let body = procedure[6] let errors = getErrorTypes(procedure) procedure[6] = quote do: - convertCustomErrors[`errors`]: + try: `body` + except ProviderError as error: + if data =? error.data: + let convert = customErrorConversion(`errors`) + raise convert(error) + else: + raise error func addFuture(procedure: var NimNode) = let returntype = procedure[3][0] diff --git a/ethers/errors.nim b/ethers/errors.nim index 5205249..b45aee4 100644 --- a/ethers/errors.nim +++ b/ethers/errors.nim @@ -1,5 +1,3 @@ -import pkg/contractabi -import pkg/contractabi/selector import ./basics type SolidityError* = object of EthersError @@ -7,46 +5,3 @@ type SolidityError* = object of EthersError {.push raises:[].} template errors*(types) {.pragma.} - -func selector(E: type): FunctionSelector = - when compiles(E.arguments): - selector($E, typeof(E.arguments)) - else: - selector($E, tuple[]) - -func matchesSelector(E: type, data: seq[byte]): bool = - const selector = E.selector.toArray - data.len >= 4 and selector[0..<4] == data[0..<4] - -func decodeArguments(E: type, data: seq[byte]): auto = - AbiDecoder.decode(data[4..^1], E.arguments) - -func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) = - if not E.matchesSelector(data): - return failure "unable to decode " & $E & ": selector doesn't match" - when compiles(E.arguments): - without arguments =? E.decodeArguments(data), error: - return failure "unable to decode " & $E & ": " & error.msg - let message = "EVM reverted: " & $E & $arguments - success (ref E)(msg: message, arguments: arguments) - else: - if data.len > 4: - return failure "unable to decode " & $E & ": unread trailing bytes found" - let message = "EVM reverted: " & $E & "()" - success (ref E)(msg: message) - -func encode*[E: SolidityError](_: type AbiEncoder, error: ref E): seq[byte] = - result = @(E.selector.toArray) - when compiles(error.arguments): - result &= AbiEncoder.encode(error.arguments) - -template convertCustomErrors*[ErrorTypes: tuple](body: untyped): untyped = - try: - body - except ProviderError as error: - block: - if data =? error.data: - for e in ErrorTypes.default.fields: - if error =? typeof(e).decode(data): - raise error - raise error diff --git a/ethers/errors/conversion.nim b/ethers/errors/conversion.nim new file mode 100644 index 0000000..22a43aa --- /dev/null +++ b/ethers/errors/conversion.nim @@ -0,0 +1,12 @@ +import ../basics +import ../provider +import ./encoding + +func customErrorConversion*(ErrorTypes: type tuple): ConvertCustomErrors = + func convert(error: ref ProviderError): ref EthersError = + if data =? error.data: + for e in ErrorTypes.default.fields: + if error =? typeof(e).decode(data): + return error + return error + convert diff --git a/ethers/errors/encoding.nim b/ethers/errors/encoding.nim new file mode 100644 index 0000000..72372eb --- /dev/null +++ b/ethers/errors/encoding.nim @@ -0,0 +1,37 @@ +import pkg/contractabi +import pkg/contractabi/selector +import ../basics +import ../errors + +func selector(E: type): FunctionSelector = + when compiles(E.arguments): + selector($E, typeof(E.arguments)) + else: + selector($E, tuple[]) + +func matchesSelector(E: type, data: seq[byte]): bool = + const selector = E.selector.toArray + data.len >= 4 and selector[0..<4] == data[0..<4] + +func decodeArguments(E: type, data: seq[byte]): auto = + AbiDecoder.decode(data[4..^1], E.arguments) + +func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) = + if not E.matchesSelector(data): + return failure "unable to decode " & $E & ": selector doesn't match" + when compiles(E.arguments): + without arguments =? E.decodeArguments(data), error: + return failure "unable to decode " & $E & ": " & error.msg + let message = "EVM reverted: " & $E & $arguments + success (ref E)(msg: message, arguments: arguments) + else: + if data.len > 4: + return failure "unable to decode " & $E & ": unread trailing bytes found" + let message = "EVM reverted: " & $E & "()" + success (ref E)(msg: message) + +func encode*[E: SolidityError](_: type AbiEncoder, error: ref E): seq[byte] = + result = @(E.selector.toArray) + when compiles(error.arguments): + result &= AbiEncoder.encode(error.arguments) + diff --git a/ethers/provider.nim b/ethers/provider.nim index 4b585fe..efce84c 100644 --- a/ethers/provider.nim +++ b/ethers/provider.nim @@ -1,9 +1,9 @@ import pkg/chronicles import pkg/serde -import pkg/stew/byteutils import ./basics import ./transaction import ./blocktag +import ./errors export basics export transaction @@ -41,6 +41,9 @@ type TransactionResponse* = object provider*: Provider hash* {.serialize.}: TransactionHash + convertCustomErrors*: ConvertCustomErrors + ConvertCustomErrors* = + proc(error: ref ProviderError): ref EthersError {.gcsafe, raises:[].} TransactionReceipt* {.serialize.} = object sender* {.serialize("from"), deserialize("from").}: ?Address to*: ?Address @@ -267,7 +270,11 @@ proc confirm*( if txBlockNumber + confirmations.u256 <= blockNumber + 1: await subscription.unsubscribe() - await tx.provider.ensureSuccess(receipt) + try: + await tx.provider.ensureSuccess(receipt) + except ProviderError as error: + if convert =? tx.convertCustomErrors: + raise convert(error) return receipt proc confirm*( diff --git a/testmodule/testCustomErrors.nim b/testmodule/testCustomErrors.nim index 92156ef..752ea86 100644 --- a/testmodule/testCustomErrors.nim +++ b/testmodule/testCustomErrors.nim @@ -111,3 +111,27 @@ suite "Contract custom errors": except ErrorWithArguments as error: check error.arguments.one == 1.u256 check error.arguments.two == true + + test "handles transaction confirmation errors": + proc revertsTransaction(contract: TestCustomErrors): ?TransactionResponse + {.contract, errors:[ErrorWithArguments].} + + # skip gas estimation + let overrides = TransactionOverrides(gasLimit: some 1000000.u256) + + # ensure that transaction is not immediately checked by hardhat + discard await provider.send("evm_setAutomine", @[%false]) + + let contract = contract.connect(provider.getSigner()) + try: + let future = contract.revertsTransaction(overrides = overrides).confirm(0) + await sleepAsync(100.millis) # wait for transaction to be submitted + discard await provider.send("evm_mine", @[]) # mine the transaction + discard await future # wait for confirmation + fail() + except ErrorWithArguments as error: + check error.arguments.one == 1.u256 + check error.arguments.two == true + + # re-enable auto mining + discard await provider.send("evm_setAutomine", @[%true]) diff --git a/testmodule/testErrorDecoding.nim b/testmodule/testErrorDecoding.nim index a63838a..8ae4cea 100644 --- a/testmodule/testErrorDecoding.nim +++ b/testmodule/testErrorDecoding.nim @@ -3,6 +3,7 @@ import std/strutils import pkg/questionable/results import pkg/contractabi import pkg/ethers/errors +import pkg/ethers/errors/encoding suite "Decoding of custom errors":