support custom errors with arguments

This commit is contained in:
Mark Spanbroek 2024-03-20 11:24:53 +01:00 committed by markspanbroek
parent 74f15fca9c
commit 9c76803302
3 changed files with 53 additions and 8 deletions

View File

@ -1,3 +1,4 @@
import pkg/contractabi
import pkg/contractabi/selector import pkg/contractabi/selector
import ./basics import ./basics
@ -7,14 +8,33 @@ type SolidityError* = object of EthersError
template errors*(types) {.pragma.} template errors*(types) {.pragma.}
func selector(E: type): FunctionSelector =
when compiles(E.arguments):
selector($E, typeof(E.arguments))
else:
selector($E, tuple[])
func matchesSelector(E: type, data: seq[byte]): bool =
const selector = E.selector.toArray
data.len >= 4 and selector[0..<4] == data[0..<4]
func decodeArguments(E: type, data: seq[byte]): auto =
AbiDecoder.decode(data[4..^1], E.arguments)
func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) = func decode*[E: SolidityError](_: type E, data: seq[byte]): ?!(ref E) =
const name = $E if not E.matchesSelector(data):
const selector = selector(name, typeof(())) return failure "unable to decode " & $E & ": selector doesn't match"
if data.len < 4: when compiles(E.arguments):
return failure "unable to decode " & name & ": signature too short" without arguments =? E.decodeArguments(data), error:
if selector.toArray[0..<4] != data[0..<4]: return failure "unable to decode " & $E & ": " & error.msg
return failure "unable to decode " & name & ": signature doesn't match" success (ref E)(arguments: arguments)
success (ref E)() else:
success (ref E)()
func encode*[E: SolidityError](_: type AbiEncoder, error: ref E): seq[byte] =
result = @(E.selector.toArray)
when compiles(error.arguments):
result &= AbiEncoder.encode(error.arguments)
template convertCustomErrors*[ErrorTypes: tuple](body: untyped): untyped = template convertCustomErrors*[ErrorTypes: tuple](body: untyped): untyped =
try: try:

View File

@ -8,6 +8,8 @@ suite "Contract custom errors":
type type
TestCustomErrors = ref object of Contract TestCustomErrors = ref object of Contract
SimpleError = object of SolidityError SimpleError = object of SolidityError
ErrorWithArguments = object of SolidityError
arguments: tuple[one: UInt256, two: bool]
var contract: TestCustomErrors var contract: TestCustomErrors
var provider: JsonRpcProvider var provider: JsonRpcProvider
@ -30,3 +32,14 @@ suite "Contract custom errors":
expect SimpleError: expect SimpleError:
await contract.revertsSimpleError() await contract.revertsSimpleError()
test "handles error with arguments":
proc revertsErrorWithArguments(contract: TestCustomErrors)
{.contract, pure, errors:[ErrorWithArguments].}
try:
await contract.revertsErrorWithArguments()
fail()
except ErrorWithArguments as error:
check error.arguments.one == 1
check error.arguments.two == true

View File

@ -1,10 +1,14 @@
import std/unittest import std/unittest
import pkg/questionable/results import pkg/questionable/results
import pkg/contractabi
import pkg/ethers/errors import pkg/ethers/errors
suite "Decoding of custom errors": suite "Decoding of custom errors":
type SimpleError = object of SolidityError type
SimpleError = object of SolidityError
ErrorWithArguments = object of SolidityError
arguments: tuple[one: UInt256, two: bool]
test "decodes a simple error": test "decodes a simple error":
let decoded = SimpleError.decode(@[0xc2'u8, 0xbb, 0x94, 0x7c]) let decoded = SimpleError.decode(@[0xc2'u8, 0xbb, 0x94, 0x7c])
@ -12,6 +16,14 @@ suite "Decoding of custom errors":
check decoded.isSuccess check decoded.isSuccess
check (!decoded) != nil check (!decoded) != nil
test "decodes error with arguments":
let expected = (ref ErrorWithArguments)(arguments: (1.u256, true))
let encoded = AbiEncoder.encode(expected)
let decoded = ErrorWithArguments.decode(encoded)
check decoded.isSuccess
check (!decoded).arguments.one == 1.u256
check (!decoded).arguments.two == true
test "returns failure when decoding fails": test "returns failure when decoding fails":
let invalid = @[0xc2'u8, 0xbb, 0x94, 0x0] # last byte is wrong let invalid = @[0xc2'u8, 0xbb, 0x94, 0x0] # last byte is wrong
let decoded = SimpleError.decode(invalid) let decoded = SimpleError.decode(invalid)