handle custom errors in confirm() calls

This commit is contained in:
Mark Spanbroek 2024-03-21 09:32:15 +01:00 committed by markspanbroek
parent 067e0f2eb7
commit cdb230d30f
7 changed files with 102 additions and 52 deletions

View File

@ -9,6 +9,7 @@ import ./provider
import ./signer
import ./events
import ./errors
import ./errors/conversion
import ./fields
export basics
@ -122,12 +123,14 @@ proc call(contract: Contract,
proc send(contract: Contract,
function: string,
parameters: tuple,
overrides = TransactionOverrides()):
overrides = TransactionOverrides(),
convertCustomErrors: ConvertCustomErrors = nil):
Future[?TransactionResponse] {.async.} =
if signer =? contract.signer:
let transaction = createTransaction(contract, function, parameters, overrides)
let populated = await signer.populateTransaction(transaction)
let txResp = await signer.sendTransaction(populated)
var txResp = await signer.sendTransaction(populated)
txResp.convertCustomErrors = convertCustomErrors
return txResp.some
else:
await call(contract, function, parameters, overrides)
@ -150,7 +153,10 @@ func getErrorTypes(procedure: NimNode): NimNode =
pragma[1].expectKind(nnkBracket)
for error in pragma[1]:
tupl.add error
tupl
if tupl.len == 0:
quote do: tuple[]
else:
tupl
func isGetter(procedure: NimNode): bool =
let pragmas = procedure[4]
@ -192,6 +198,7 @@ func addContractCall(procedure: var NimNode) =
let isGetter = procedure.isGetter
procedure.addOverrides()
let errors = getErrorTypes(procedure)
func call: NimNode =
if returnType.kind == nnkEmpty:
@ -221,7 +228,8 @@ func addContractCall(procedure: var NimNode) =
"unexpected return type, " &
"missing {.view.}, {.pure.} or {.getter.} ?"
.}
return await send(`contract`, `function`, `parameters`, overrides)
let convert = customErrorConversion(`errors`)
return await send(`contract`, `function`, `parameters`, overrides, convert)
procedure[6] =
if procedure.isConstant:
@ -233,8 +241,14 @@ func addErrorHandling(procedure: var NimNode) =
let body = procedure[6]
let errors = getErrorTypes(procedure)
procedure[6] = quote do:
convertCustomErrors[`errors`]:
try:
`body`
except ProviderError as error:
if data =? error.data:
let convert = customErrorConversion(`errors`)
raise convert(error)
else:
raise error
func addFuture(procedure: var NimNode) =
let returntype = procedure[3][0]

View File

@ -1,5 +1,3 @@
import pkg/contractabi
import pkg/contractabi/selector
import ./basics
type SolidityError* = object of EthersError
@ -7,46 +5,3 @@ type SolidityError* = object of EthersError
{.push raises:[].}
template errors*(types) {.pragma.}
func selector(E: type): FunctionSelector =
when compiles(E.arguments):
selector($E, typeof(E.arguments))
else:
selector($E, tuple[])
func matchesSelector(E: type, data: seq[byte]): bool =
const selector = E.selector.toArray
data.len >= 4 and selector[0..<4] == data[0..<4]
func decodeArguments(E: type, data: seq[byte]): auto =
AbiDecoder.decode(data[4..^1], E.arguments)
func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) =
if not E.matchesSelector(data):
return failure "unable to decode " & $E & ": selector doesn't match"
when compiles(E.arguments):
without arguments =? E.decodeArguments(data), error:
return failure "unable to decode " & $E & ": " & error.msg
let message = "EVM reverted: " & $E & $arguments
success (ref E)(msg: message, arguments: arguments)
else:
if data.len > 4:
return failure "unable to decode " & $E & ": unread trailing bytes found"
let message = "EVM reverted: " & $E & "()"
success (ref E)(msg: message)
func encode*[E: SolidityError](_: type AbiEncoder, error: ref E): seq[byte] =
result = @(E.selector.toArray)
when compiles(error.arguments):
result &= AbiEncoder.encode(error.arguments)
template convertCustomErrors*[ErrorTypes: tuple](body: untyped): untyped =
try:
body
except ProviderError as error:
block:
if data =? error.data:
for e in ErrorTypes.default.fields:
if error =? typeof(e).decode(data):
raise error
raise error

View File

@ -0,0 +1,12 @@
import ../basics
import ../provider
import ./encoding
func customErrorConversion*(ErrorTypes: type tuple): ConvertCustomErrors =
func convert(error: ref ProviderError): ref EthersError =
if data =? error.data:
for e in ErrorTypes.default.fields:
if error =? typeof(e).decode(data):
return error
return error
convert

View File

@ -0,0 +1,37 @@
import pkg/contractabi
import pkg/contractabi/selector
import ../basics
import ../errors
func selector(E: type): FunctionSelector =
when compiles(E.arguments):
selector($E, typeof(E.arguments))
else:
selector($E, tuple[])
func matchesSelector(E: type, data: seq[byte]): bool =
const selector = E.selector.toArray
data.len >= 4 and selector[0..<4] == data[0..<4]
func decodeArguments(E: type, data: seq[byte]): auto =
AbiDecoder.decode(data[4..^1], E.arguments)
func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) =
if not E.matchesSelector(data):
return failure "unable to decode " & $E & ": selector doesn't match"
when compiles(E.arguments):
without arguments =? E.decodeArguments(data), error:
return failure "unable to decode " & $E & ": " & error.msg
let message = "EVM reverted: " & $E & $arguments
success (ref E)(msg: message, arguments: arguments)
else:
if data.len > 4:
return failure "unable to decode " & $E & ": unread trailing bytes found"
let message = "EVM reverted: " & $E & "()"
success (ref E)(msg: message)
func encode*[E: SolidityError](_: type AbiEncoder, error: ref E): seq[byte] =
result = @(E.selector.toArray)
when compiles(error.arguments):
result &= AbiEncoder.encode(error.arguments)

View File

@ -1,9 +1,9 @@
import pkg/chronicles
import pkg/serde
import pkg/stew/byteutils
import ./basics
import ./transaction
import ./blocktag
import ./errors
export basics
export transaction
@ -41,6 +41,9 @@ type
TransactionResponse* = object
provider*: Provider
hash* {.serialize.}: TransactionHash
convertCustomErrors*: ConvertCustomErrors
ConvertCustomErrors* =
proc(error: ref ProviderError): ref EthersError {.gcsafe, raises:[].}
TransactionReceipt* {.serialize.} = object
sender* {.serialize("from"), deserialize("from").}: ?Address
to*: ?Address
@ -267,7 +270,11 @@ proc confirm*(
if txBlockNumber + confirmations.u256 <= blockNumber + 1:
await subscription.unsubscribe()
await tx.provider.ensureSuccess(receipt)
try:
await tx.provider.ensureSuccess(receipt)
except ProviderError as error:
if convert =? tx.convertCustomErrors:
raise convert(error)
return receipt
proc confirm*(

View File

@ -111,3 +111,27 @@ suite "Contract custom errors":
except ErrorWithArguments as error:
check error.arguments.one == 1.u256
check error.arguments.two == true
test "handles transaction confirmation errors":
proc revertsTransaction(contract: TestCustomErrors): ?TransactionResponse
{.contract, errors:[ErrorWithArguments].}
# skip gas estimation
let overrides = TransactionOverrides(gasLimit: some 1000000.u256)
# ensure that transaction is not immediately checked by hardhat
discard await provider.send("evm_setAutomine", @[%false])
let contract = contract.connect(provider.getSigner())
try:
let future = contract.revertsTransaction(overrides = overrides).confirm(0)
await sleepAsync(100.millis) # wait for transaction to be submitted
discard await provider.send("evm_mine", @[]) # mine the transaction
discard await future # wait for confirmation
fail()
except ErrorWithArguments as error:
check error.arguments.one == 1.u256
check error.arguments.two == true
# re-enable auto mining
discard await provider.send("evm_setAutomine", @[%true])

View File

@ -3,6 +3,7 @@ import std/strutils
import pkg/questionable/results
import pkg/contractabi
import pkg/ethers/errors
import pkg/ethers/errors/encoding
suite "Decoding of custom errors":