Rework of ENR decoding code (#709)

- Rework to have exception raise only at rlp decoding and use
result types from then onwards
- Adjust the current API to have result versions and deprecated
the ones which had var Record + bool
- Add PublickKey to the Record object, as this allows us to skip
fromRaw calls whenever access is needed to the public key
- Add a TypedRecord.fromRecord which cannot fail and deprecate
the old one
- Some other minor clean-up & re-ordering
This commit is contained in:
Kim De Mey 2024-06-27 15:15:23 +02:00 committed by GitHub
parent 7f20d79945
commit d7577f59d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 184 additions and 145 deletions

View File

@ -57,9 +57,14 @@ type
## must remain sorted and without duplicate keys. Use the insert func to
## ensure this.
raw*: seq[byte] ## RLP encoded record
publicKey: PublicKey ## Public key of the record
EnrUri* = distinct string
# TODO: I think it makes more sense to have the directly usable types for the
# fields here because in its current for you might as well just access the
# pairs in a Record directly. This would break the current API unless the type
# gets renamed.
TypedRecord* = object
id*: string
secp256k1*: Opt[array[33, byte]]
@ -356,26 +361,20 @@ func tryGet*(r: Record, key: string, T: type): Opt[T] =
## according to type `T`.
get(r, key, T).optValue()
func toTypedRecord*(r: Record): EnrResult[TypedRecord] =
let id = r.tryGet("id", string)
if id.isSome:
var tr: TypedRecord
tr.id = id.get
func fromRecord*(T: type TypedRecord, r: Record): T =
TypedRecord(
id: r.get("id", string).expect("Record must always have id field"),
secp256k1: r.tryGet("secp256k1", array[33, byte]),
ip: r.tryGet("ip", array[4, byte]),
ip6: r.tryGet("ip6", array[16, byte]),
tcp: r.tryGet("tcp", int),
tcp6: r.tryGet("tcp6", int),
udp: r.tryGet("udp", int),
udp6: r.tryGet("udp6", int)
)
template readField(fieldName: untyped) {.dirty.} =
tr.fieldName = tryGet(r, astToStr(fieldName), type(tr.fieldName.get))
readField secp256k1
readField ip
readField ip6
readField tcp
readField tcp6
readField udp
readField udp6
ok(tr)
else:
err("Record without id field")
func toTypedRecord*(r: Record): EnrResult[TypedRecord] {.deprecated: "Please use TypedRecord.fromRecord instead".} =
ok(TypedRecord.fromRecord(r))
func contains*(r: Record, fp: (string, seq[byte])): bool =
# TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the
@ -387,118 +386,161 @@ func contains*(r: Record, fp: (string, seq[byte])): bool =
false
func verifySignatureV4(
r: Record, sigData: openArray[byte], content: seq[byte]): bool =
let publicKey = r.get(PublicKey)
if publicKey.isNone():
return false
let sig = SignatureNR.fromRaw(sigData)
if sig.isOk():
var h = keccak256.digest(content)
verify(sig[], SkMessage(h.data), publicKey.get)
publicKey: PublicKey, sigData: openArray[byte], content: openArray[byte]): EnrResult[void] =
## Verify the signature for the "v4" identity scheme
let signature = ?SignatureNR.fromRaw(sigData)
let hash = keccak256.digest(content)
if verify(signature, SkMessage(hash.data), publicKey):
ok()
else:
false
err("Signature verfication failed")
template rlpResult(body: untyped): auto =
try:
body
except RlpError:
return err("Invalid RLP list")
func buildRlpContent(bytes: openArray[byte]): EnrResult[seq[byte]] =
## Rebuild the encoded RLP content without the signature. This is used to
## verify the signature.
var rlp = rlpFromBytes(bytes)
let listLen = rlpResult rlp.listLen
doAssert rlp.enterList()
# skip signature
rlpResult rlp.skipElem()
func verifySignature(r: Record): bool {.raises: [RlpError].} =
var rlp = rlpFromBytes(r.raw)
let sz = rlp.listLen
if not rlp.enterList:
return false
let sigData = rlp.read(seq[byte])
let content = block:
var writer = initRlpList(sz - 1)
var reader = rlp
for i in 1 ..< sz:
writer.appendRawBytes(reader.rawData)
reader.skipElem
var writer = initRlpList(listLen - 1)
for i in 1 ..< listLen:
rlpResult:
writer.appendRawBytes(rlp.rawData)
rlp.skipElem()
writer.finish()
var id: Field
if r.getField("id", id) and id.kind == kString:
case id.str
of "v4":
verifySignatureV4(r, sigData, content)
else:
# Unknown Identity Scheme
false
else:
# No Identity Scheme provided
false
ok(content)
func fromBytesAux(r: var Record): bool {.raises: [RlpError].} =
if r.raw.len > maxEnrSize:
return false
func fromBytesAux(T: type Record, s: openArray[byte]): EnrResult[T] =
## Creates ENR from rlp-encoded bytes and verifies the signature.
if s.len > maxEnrSize:
return err("Record exceeds maximum size")
var rlp = rlpFromBytes(r.raw)
var rlp = rlpFromBytes(s)
if not rlp.isList:
return false
return err("Record does not contain valid RLP list")
let sz = rlp.listLen
let sz = rlpResult rlp.listLen
if sz < minRlpListLen or sz mod 2 != 0:
# Wrong rlp object
return false
return err("Wrong RLP list length")
# We already know we are working with a list
doAssert rlp.enterList()
rlp.skipElem() # Skip signature
r.seqNum = rlp.read(uint64)
let
signatureRaw = rlpResult rlp.read(seq[byte])
seqNum = rlpResult rlp.read(uint64)
numPairs = (sz - 2) div 2
let numPairs = (sz - 2) div 2
var
pairs = newSeqOfCap[FieldPair](numPairs)
id: string = ""
pkRaw = Opt.none(seq[byte])
for i in 0 ..< numPairs:
let k = rlp.read(string)
let k = rlpResult rlp.read(string)
case k
of "id":
let id = rlp.read(string)
r.pairs.add((k, Field(kind: kString, str: id)))
id = rlpResult rlp.read(string)
pairs.add((k, Field(kind: kString, str: id)))
of "secp256k1":
let pubkeyData = rlp.read(seq[byte])
r.pairs.add((k, Field(kind: kBytes, bytes: pubkeyData)))
pkRaw = Opt.some rlpResult rlp.read(seq[byte])
pairs.add((k, Field(kind: kBytes, bytes: pkRaw.value())))
of "tcp", "udp", "tcp6", "udp6":
let v = rlp.read(uint16)
r.pairs.add((k, Field(kind: kNum, num: v)))
let v = rlpResult rlp.read(uint16)
pairs.add((k, Field(kind: kNum, num: v)))
else:
# Don't know really what this is supposed to represent so drop it in
# `kBytes` field pair when a single byte or blob.
if rlp.isSingleByte() or rlp.isBlob():
r.pairs.add((k, Field(kind: kBytes, bytes: rlp.read(seq[byte]))))
let bytes = rlpResult rlp.read(seq[byte])
pairs.add((k, Field(kind: kBytes, bytes: bytes)))
elif rlp.isList():
# Not supporting decoding lists as value (especially unknown ones),
# just drop the raw RLP value in there.
r.pairs.add((k, Field(kind: kList, listRaw: @(rlp.rawData()))))
pairs.add((k, Field(kind: kList, listRaw: @(rlpResult rlp.rawData()))))
# Need to skip the element still.
rlp.skipElem()
rlpResult rlp.skipElem()
verifySignature(r)
# Storing the PublicKey in the Record as `fromRaw` is relatively expensive.
let pk: PublicKey =
case id
of "":
return err("No id k:v pair in the ENR")
of "v4":
let content = ?buildRlpContent(s)
if pkRaw.isNone():
return err("No secp256k1 k:v pair in the ENR")
let pk = ?PublicKey.fromRaw(pkRaw.value())
?verifySignatureV4(pk, signatureRaw, content)
pk
else:
return err("Unknown Identity Scheme")
func fromBytes*(r: var Record, s: openArray[byte]): bool =
## Loads ENR from rlp-encoded bytes, and validates the signature.
r.raw = @s
try:
fromBytesAux(r)
except RlpError:
false
ok(Record(
seqNum: seqNum,
pairs: pairs,
raw: @s,
publicKey: pk
))
func fromBase64*(r: var Record, s: string): bool =
## Loads ENR from base64-encoded rlp-encoded bytes, and validates the
func fromBytes*(T: type Record, s: openArray[byte]): EnrResult[T] =
## Creates ENR from rlp-encoded bytes and verifies the signature.
Record.fromBytesAux(s)
func fromBytes*(r: var Record, s: openArray[byte]): bool {.deprecated: "Use the Result[Record] version instead".} =
## Loads ENR from rlp-encoded bytes and verifies the signature.
r = Record.fromBytes(s).valueOr:
return false
true
func fromBase64*(T: type Record, s: string): EnrResult[T] =
## Creates ENR from base64-encoded rlp-encoded bytes and verifies the
## signature.
try:
r.raw = Base64Url.decode(s)
fromBytesAux(r)
except RlpError, Base64Error:
false
let rlpRaw =
try:
Base64Url.decode(s)
except Base64Error:
return err("Base64 decoding error")
func fromURI*(r: var Record, s: string): bool =
## Loads ENR from its text encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Validates the signature.
Record.fromBytesAux(rlpRaw)
func fromBase64*(r: var Record, s: string): bool {.deprecated: "Use the Result[Record] version instead".} =
## Loads ENR from base64-encoded rlp-encoded bytes and verifies the
## signature.
r = Record.fromBase64(s).valueOr:
return false
true
func fromURI*(T: type Record, s: string): EnrResult[T] =
## Creates ENR from its URI encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Verifies the signature.
const prefix = "enr:"
if s.startsWith(prefix):
r.fromBase64(s[prefix.len .. ^1])
Record.fromBase64(s[prefix.len .. ^1])
else:
false
err("Invalid URI prefix")
template fromURI*(r: var Record, url: EnrUri): bool =
func fromURI*(r: var Record, s: string): bool {.deprecated: "Use the Result[Record] version instead".} =
## Loads ENR from its URI encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Verifies the signature.
r = Record.fromURI(s).valueOr:
return false
true
template fromURI*(r: var Record, url: EnrUri): bool {.deprecated: "Use the Result[Record] version instead".} =
fromURI(r, string(url))
func toBase64*(r: Record): string =
@ -551,15 +593,17 @@ func `==`*(a, b: Record): bool = a.raw == b.raw
func read*(
rlp: var Rlp, T: type Record):
T {.raises: [RlpError, ValueError].} =
var res: T
if not rlp.hasData() or not res.fromBytes(rlp.rawData()):
# TODO: This could also just be an invalid signature, would be cleaner to
# split of RLP deserialisation errors from this.
raise newException(ValueError, "Could not deserialize")
T {.raises: [RlpError].} =
if not rlp.hasData():
raise newException(RlpError, "Empty RLP data")
let res = T.fromBytes(rlp.rawData())
if res.isErr:
raise newException(RlpError, $res.error)
rlp.skipElem()
res
res.value
func append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw)

View File

@ -15,6 +15,8 @@ import
std/[hashes, net],
./enr
export enr
type
MessageKind* = enum
# Note:

View File

@ -14,7 +14,7 @@ import
stew/arrayops,
results,
../../rlp,
"."/[messages, enr]
"."/messages
from stew/objects import checkedEnumAssign
@ -94,7 +94,7 @@ func decodeMessage*(body: openArray[byte]): Result[Message, cstring] =
return err("Invalid request-id")
func decode[T](rlp: var Rlp, v: var T)
{.nimcall, raises: [RlpError, ValueError].} =
{.nimcall, raises: [RlpError].} =
for k, v in v.fieldPairs:
v = rlp.read(typeof(v))

View File

@ -1,5 +1,5 @@
import
std/[os, strutils, options, net],
std/[os, strutils, net],
../../../eth/keys, ../../../eth/p2p/discoveryv5/enr,
../fuzzing_helpers

View File

@ -32,12 +32,10 @@ suite "ENR test vector tests":
udp = 0x765f
test "Test vector full encode loop":
var r: Record
let valid = r.fromURI(uri)
check valid
let res = toTypedRecord(r)
let res = Record.fromURI(uri)
check res.isOk()
let typedRecord = res.value
let r = res.value()
let typedRecord = TypedRecord.fromRecord(r)
check:
r.seqNum == seqNum
typedRecord.id == id
@ -77,25 +75,22 @@ suite "ENR encoding tests":
testRlpEncodingLoop(enr.value)
test "Empty RLP":
expect ValueError:
expect RlpError:
let _ = rlp.decode([], enr.Record)
var r: Record
check not fromBytes(r, [])
check Record.fromBytes([]).isErr()
test "Invalid RLP":
expect RlpError:
let _ = rlp.decode([byte 0xf7], enr.Record)
var r: Record
check not fromBytes(r, [byte 0xf7])
check Record.fromBytes([byte 0xf7]).isErr()
test "No RLP list":
expect ValueError:
expect RlpError:
let _ = rlp.decode([byte 0x7f], enr.Record)
var r: Record
check not fromBytes(r, [byte 0x7f])
check Record.fromBytes([byte 0x7f]).isErr()
test "ENR with RLP list value":
type
@ -121,15 +116,14 @@ suite "ENR encoding tests":
test "Base64 encode loop":
const encodedBase64 = "-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnh4MHcBFZntXNFrdvJjX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8"
var r: Record
let res = Record.fromBase64(encodedBase64)
check:
r.fromBase64(encodedBase64)
toBase64(r) == encodedBase64
res.isOk()
toBase64(res.value) == encodedBase64
test "Invalid base64":
var r: Record
let valid = r.fromBase64("-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnhMHcBFZntXNFrdv*jX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8")
check not valid
let res = Record.fromBase64("-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnhMHcBFZntXNFrdv*jX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8")
check res.isErr()
test "URI encode loop":
let
@ -141,20 +135,16 @@ suite "ENR encoding tests":
check res.isOk()
let enr = res.value()
let uri = enr.toURI()
var enr2: Record
let valid = enr2.fromURI(uri)
check(valid)
check(enr == enr2)
let res2 = Record.fromURI(uri)
check:
res2.isOk()
enr == res2.value()
test "Invalid URI: empty":
var r: Record
let valid = r.fromURI("")
check not valid
check Record.fromURI("").isErr()
test "Invalid URI: no payload":
var r: Record
let valid = r.fromURI("enr:")
check not valid
check Record.fromURI("enr:").isErr()
suite "ENR init tests":
test "Record.init minimum fields":
@ -163,7 +153,7 @@ suite "ENR init tests":
port = Opt.none(Port)
enr = Record.init(
100, keypair.seckey, Opt.none(IpAddress), port, port)[]
typedEnr = get enr.toTypedRecord()
typedEnr = TypedRecord.fromRecord(enr)
check:
testRlpEncodingLoop(enr)
@ -186,7 +176,7 @@ suite "ENR init tests":
port = Opt.some(Port(9000))
enr = Record.init(
100, keypair.seckey, Opt.some(ip), port, port)[]
typedEnr = get enr.toTypedRecord()
typedEnr = TypedRecord.fromRecord(enr)
check:
typedEnr.ip.isSome()
@ -205,7 +195,7 @@ suite "ENR init tests":
port = Opt.some(Port(9000))
enr = Record.init(
100, keypair.seckey, Opt.some(ip), port, port)[]
typedEnr = get enr.toTypedRecord()
typedEnr = TypedRecord.fromRecord(enr)
check:
typedEnr.ip.isNone()
@ -400,7 +390,7 @@ suite "ENR update tests":
Opt.some(Port(9000)), Opt.some(Port(9000)))
check updated.isOk()
let typedEnr = r.toTypedRecord().get()
let typedEnr = TypedRecord.fromRecord(r)
check:
typedEnr.ip.isSome()
@ -419,7 +409,7 @@ suite "ENR update tests":
Opt.some(Port(9001)), Opt.some(Port(9001)))
check updated.isOk()
let typedEnr = r.toTypedRecord().get()
let typedEnr = TypedRecord.fromRecord(r)
check:
typedEnr.ip.isSome()

View File

@ -117,25 +117,28 @@ type
name: "node" .}: Node
proc parseCmdArg*(T: type enr.Record, p: string): T {.raises: [ValueError].} =
if not fromURI(result, p):
raise newException(ValueError, "Invalid ENR")
let res = enr.Record.fromURI(p)
if res.isErr:
raise newException(ValueError, "Invalid ENR:" & $res.error)
res.value
proc completeCmdArg*(T: type enr.Record, val: string): seq[string] =
return @[]
proc parseCmdArg*(T: type Node, p: string): T {.raises: [ValueError].} =
var record: enr.Record
if not fromURI(record, p):
raise newException(ValueError, "Invalid ENR")
let res = enr.Record.fromURI(p)
if res.isErr:
raise newException(ValueError, "Invalid ENR:" & $res.error)
let n = newNode(record)
let n = newNode(res.value)
if n.isErr:
raise newException(ValueError, $n.error)
if n[].address.isNone():
if n.value.address.isNone():
raise newException(ValueError, "ENR without address")
n[]
n.value
proc completeCmdArg*(T: type Node, val: string): seq[string] =
return @[]