Rework of ENR decoding code (#709)

- Rework to have exception raise only at rlp decoding and use
result types from then onwards
- Adjust the current API to have result versions and deprecated
the ones which had var Record + bool
- Add PublickKey to the Record object, as this allows us to skip
fromRaw calls whenever access is needed to the public key
- Add a TypedRecord.fromRecord which cannot fail and deprecate
the old one
- Some other minor clean-up & re-ordering
This commit is contained in:
Kim De Mey 2024-06-27 15:15:23 +02:00 committed by GitHub
parent 7f20d79945
commit d7577f59d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 184 additions and 145 deletions

View File

@ -57,9 +57,14 @@ type
## must remain sorted and without duplicate keys. Use the insert func to ## must remain sorted and without duplicate keys. Use the insert func to
## ensure this. ## ensure this.
raw*: seq[byte] ## RLP encoded record raw*: seq[byte] ## RLP encoded record
publicKey: PublicKey ## Public key of the record
EnrUri* = distinct string EnrUri* = distinct string
# TODO: I think it makes more sense to have the directly usable types for the
# fields here because in its current for you might as well just access the
# pairs in a Record directly. This would break the current API unless the type
# gets renamed.
TypedRecord* = object TypedRecord* = object
id*: string id*: string
secp256k1*: Opt[array[33, byte]] secp256k1*: Opt[array[33, byte]]
@ -356,26 +361,20 @@ func tryGet*(r: Record, key: string, T: type): Opt[T] =
## according to type `T`. ## according to type `T`.
get(r, key, T).optValue() get(r, key, T).optValue()
func toTypedRecord*(r: Record): EnrResult[TypedRecord] = func fromRecord*(T: type TypedRecord, r: Record): T =
let id = r.tryGet("id", string) TypedRecord(
if id.isSome: id: r.get("id", string).expect("Record must always have id field"),
var tr: TypedRecord secp256k1: r.tryGet("secp256k1", array[33, byte]),
tr.id = id.get ip: r.tryGet("ip", array[4, byte]),
ip6: r.tryGet("ip6", array[16, byte]),
tcp: r.tryGet("tcp", int),
tcp6: r.tryGet("tcp6", int),
udp: r.tryGet("udp", int),
udp6: r.tryGet("udp6", int)
)
template readField(fieldName: untyped) {.dirty.} = func toTypedRecord*(r: Record): EnrResult[TypedRecord] {.deprecated: "Please use TypedRecord.fromRecord instead".} =
tr.fieldName = tryGet(r, astToStr(fieldName), type(tr.fieldName.get)) ok(TypedRecord.fromRecord(r))
readField secp256k1
readField ip
readField ip6
readField tcp
readField tcp6
readField udp
readField udp6
ok(tr)
else:
err("Record without id field")
func contains*(r: Record, fp: (string, seq[byte])): bool = func contains*(r: Record, fp: (string, seq[byte])): bool =
# TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the # TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the
@ -387,118 +386,161 @@ func contains*(r: Record, fp: (string, seq[byte])): bool =
false false
func verifySignatureV4( func verifySignatureV4(
r: Record, sigData: openArray[byte], content: seq[byte]): bool = publicKey: PublicKey, sigData: openArray[byte], content: openArray[byte]): EnrResult[void] =
let publicKey = r.get(PublicKey) ## Verify the signature for the "v4" identity scheme
if publicKey.isNone(): let signature = ?SignatureNR.fromRaw(sigData)
return false let hash = keccak256.digest(content)
if verify(signature, SkMessage(hash.data), publicKey):
let sig = SignatureNR.fromRaw(sigData) ok()
if sig.isOk():
var h = keccak256.digest(content)
verify(sig[], SkMessage(h.data), publicKey.get)
else: else:
false err("Signature verfication failed")
template rlpResult(body: untyped): auto =
try:
body
except RlpError:
return err("Invalid RLP list")
func buildRlpContent(bytes: openArray[byte]): EnrResult[seq[byte]] =
## Rebuild the encoded RLP content without the signature. This is used to
## verify the signature.
var rlp = rlpFromBytes(bytes)
let listLen = rlpResult rlp.listLen
doAssert rlp.enterList()
# skip signature
rlpResult rlp.skipElem()
func verifySignature(r: Record): bool {.raises: [RlpError].} =
var rlp = rlpFromBytes(r.raw)
let sz = rlp.listLen
if not rlp.enterList:
return false
let sigData = rlp.read(seq[byte])
let content = block: let content = block:
var writer = initRlpList(sz - 1) var writer = initRlpList(listLen - 1)
var reader = rlp for i in 1 ..< listLen:
for i in 1 ..< sz: rlpResult:
writer.appendRawBytes(reader.rawData) writer.appendRawBytes(rlp.rawData)
reader.skipElem rlp.skipElem()
writer.finish() writer.finish()
var id: Field ok(content)
if r.getField("id", id) and id.kind == kString:
case id.str
of "v4":
verifySignatureV4(r, sigData, content)
else:
# Unknown Identity Scheme
false
else:
# No Identity Scheme provided
false
func fromBytesAux(r: var Record): bool {.raises: [RlpError].} = func fromBytesAux(T: type Record, s: openArray[byte]): EnrResult[T] =
if r.raw.len > maxEnrSize: ## Creates ENR from rlp-encoded bytes and verifies the signature.
return false if s.len > maxEnrSize:
return err("Record exceeds maximum size")
var rlp = rlpFromBytes(r.raw) var rlp = rlpFromBytes(s)
if not rlp.isList: if not rlp.isList:
return false return err("Record does not contain valid RLP list")
let sz = rlp.listLen let sz = rlpResult rlp.listLen
if sz < minRlpListLen or sz mod 2 != 0: if sz < minRlpListLen or sz mod 2 != 0:
# Wrong rlp object return err("Wrong RLP list length")
return false
# We already know we are working with a list # We already know we are working with a list
doAssert rlp.enterList() doAssert rlp.enterList()
rlp.skipElem() # Skip signature
r.seqNum = rlp.read(uint64) let
signatureRaw = rlpResult rlp.read(seq[byte])
seqNum = rlpResult rlp.read(uint64)
numPairs = (sz - 2) div 2
let numPairs = (sz - 2) div 2 var
pairs = newSeqOfCap[FieldPair](numPairs)
id: string = ""
pkRaw = Opt.none(seq[byte])
for i in 0 ..< numPairs: for i in 0 ..< numPairs:
let k = rlp.read(string) let k = rlpResult rlp.read(string)
case k case k
of "id": of "id":
let id = rlp.read(string) id = rlpResult rlp.read(string)
r.pairs.add((k, Field(kind: kString, str: id))) pairs.add((k, Field(kind: kString, str: id)))
of "secp256k1": of "secp256k1":
let pubkeyData = rlp.read(seq[byte]) pkRaw = Opt.some rlpResult rlp.read(seq[byte])
r.pairs.add((k, Field(kind: kBytes, bytes: pubkeyData))) pairs.add((k, Field(kind: kBytes, bytes: pkRaw.value())))
of "tcp", "udp", "tcp6", "udp6": of "tcp", "udp", "tcp6", "udp6":
let v = rlp.read(uint16) let v = rlpResult rlp.read(uint16)
r.pairs.add((k, Field(kind: kNum, num: v))) pairs.add((k, Field(kind: kNum, num: v)))
else: else:
# Don't know really what this is supposed to represent so drop it in # Don't know really what this is supposed to represent so drop it in
# `kBytes` field pair when a single byte or blob. # `kBytes` field pair when a single byte or blob.
if rlp.isSingleByte() or rlp.isBlob(): if rlp.isSingleByte() or rlp.isBlob():
r.pairs.add((k, Field(kind: kBytes, bytes: rlp.read(seq[byte])))) let bytes = rlpResult rlp.read(seq[byte])
pairs.add((k, Field(kind: kBytes, bytes: bytes)))
elif rlp.isList(): elif rlp.isList():
# Not supporting decoding lists as value (especially unknown ones), # Not supporting decoding lists as value (especially unknown ones),
# just drop the raw RLP value in there. # just drop the raw RLP value in there.
r.pairs.add((k, Field(kind: kList, listRaw: @(rlp.rawData())))) pairs.add((k, Field(kind: kList, listRaw: @(rlpResult rlp.rawData()))))
# Need to skip the element still. # Need to skip the element still.
rlp.skipElem() rlpResult rlp.skipElem()
verifySignature(r) # Storing the PublicKey in the Record as `fromRaw` is relatively expensive.
let pk: PublicKey =
case id
of "":
return err("No id k:v pair in the ENR")
of "v4":
let content = ?buildRlpContent(s)
if pkRaw.isNone():
return err("No secp256k1 k:v pair in the ENR")
let pk = ?PublicKey.fromRaw(pkRaw.value())
?verifySignatureV4(pk, signatureRaw, content)
pk
else:
return err("Unknown Identity Scheme")
func fromBytes*(r: var Record, s: openArray[byte]): bool = ok(Record(
## Loads ENR from rlp-encoded bytes, and validates the signature. seqNum: seqNum,
r.raw = @s pairs: pairs,
try: raw: @s,
fromBytesAux(r) publicKey: pk
except RlpError: ))
false
func fromBase64*(r: var Record, s: string): bool = func fromBytes*(T: type Record, s: openArray[byte]): EnrResult[T] =
## Loads ENR from base64-encoded rlp-encoded bytes, and validates the ## Creates ENR from rlp-encoded bytes and verifies the signature.
Record.fromBytesAux(s)
func fromBytes*(r: var Record, s: openArray[byte]): bool {.deprecated: "Use the Result[Record] version instead".} =
## Loads ENR from rlp-encoded bytes and verifies the signature.
r = Record.fromBytes(s).valueOr:
return false
true
func fromBase64*(T: type Record, s: string): EnrResult[T] =
## Creates ENR from base64-encoded rlp-encoded bytes and verifies the
## signature. ## signature.
try: let rlpRaw =
r.raw = Base64Url.decode(s) try:
fromBytesAux(r) Base64Url.decode(s)
except RlpError, Base64Error: except Base64Error:
false return err("Base64 decoding error")
func fromURI*(r: var Record, s: string): bool = Record.fromBytesAux(rlpRaw)
## Loads ENR from its text encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Validates the signature. func fromBase64*(r: var Record, s: string): bool {.deprecated: "Use the Result[Record] version instead".} =
## Loads ENR from base64-encoded rlp-encoded bytes and verifies the
## signature.
r = Record.fromBase64(s).valueOr:
return false
true
func fromURI*(T: type Record, s: string): EnrResult[T] =
## Creates ENR from its URI encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Verifies the signature.
const prefix = "enr:" const prefix = "enr:"
if s.startsWith(prefix): if s.startsWith(prefix):
r.fromBase64(s[prefix.len .. ^1]) Record.fromBase64(s[prefix.len .. ^1])
else: else:
false err("Invalid URI prefix")
template fromURI*(r: var Record, url: EnrUri): bool =
func fromURI*(r: var Record, s: string): bool {.deprecated: "Use the Result[Record] version instead".} =
## Loads ENR from its URI encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Verifies the signature.
r = Record.fromURI(s).valueOr:
return false
true
template fromURI*(r: var Record, url: EnrUri): bool {.deprecated: "Use the Result[Record] version instead".} =
fromURI(r, string(url)) fromURI(r, string(url))
func toBase64*(r: Record): string = func toBase64*(r: Record): string =
@ -551,15 +593,17 @@ func `==`*(a, b: Record): bool = a.raw == b.raw
func read*( func read*(
rlp: var Rlp, T: type Record): rlp: var Rlp, T: type Record):
T {.raises: [RlpError, ValueError].} = T {.raises: [RlpError].} =
var res: T if not rlp.hasData():
if not rlp.hasData() or not res.fromBytes(rlp.rawData()): raise newException(RlpError, "Empty RLP data")
# TODO: This could also just be an invalid signature, would be cleaner to
# split of RLP deserialisation errors from this. let res = T.fromBytes(rlp.rawData())
raise newException(ValueError, "Could not deserialize") if res.isErr:
raise newException(RlpError, $res.error)
rlp.skipElem() rlp.skipElem()
res res.value
func append*(rlpWriter: var RlpWriter, value: Record) = func append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw) rlpWriter.appendRawBytes(value.raw)

View File

@ -15,6 +15,8 @@ import
std/[hashes, net], std/[hashes, net],
./enr ./enr
export enr
type type
MessageKind* = enum MessageKind* = enum
# Note: # Note:

View File

@ -14,7 +14,7 @@ import
stew/arrayops, stew/arrayops,
results, results,
../../rlp, ../../rlp,
"."/[messages, enr] "."/messages
from stew/objects import checkedEnumAssign from stew/objects import checkedEnumAssign
@ -94,7 +94,7 @@ func decodeMessage*(body: openArray[byte]): Result[Message, cstring] =
return err("Invalid request-id") return err("Invalid request-id")
func decode[T](rlp: var Rlp, v: var T) func decode[T](rlp: var Rlp, v: var T)
{.nimcall, raises: [RlpError, ValueError].} = {.nimcall, raises: [RlpError].} =
for k, v in v.fieldPairs: for k, v in v.fieldPairs:
v = rlp.read(typeof(v)) v = rlp.read(typeof(v))

View File

@ -1,5 +1,5 @@
import import
std/[os, strutils, options, net], std/[os, strutils, net],
../../../eth/keys, ../../../eth/p2p/discoveryv5/enr, ../../../eth/keys, ../../../eth/p2p/discoveryv5/enr,
../fuzzing_helpers ../fuzzing_helpers

View File

@ -32,12 +32,10 @@ suite "ENR test vector tests":
udp = 0x765f udp = 0x765f
test "Test vector full encode loop": test "Test vector full encode loop":
var r: Record let res = Record.fromURI(uri)
let valid = r.fromURI(uri)
check valid
let res = toTypedRecord(r)
check res.isOk() check res.isOk()
let typedRecord = res.value let r = res.value()
let typedRecord = TypedRecord.fromRecord(r)
check: check:
r.seqNum == seqNum r.seqNum == seqNum
typedRecord.id == id typedRecord.id == id
@ -77,25 +75,22 @@ suite "ENR encoding tests":
testRlpEncodingLoop(enr.value) testRlpEncodingLoop(enr.value)
test "Empty RLP": test "Empty RLP":
expect ValueError: expect RlpError:
let _ = rlp.decode([], enr.Record) let _ = rlp.decode([], enr.Record)
var r: Record check Record.fromBytes([]).isErr()
check not fromBytes(r, [])
test "Invalid RLP": test "Invalid RLP":
expect RlpError: expect RlpError:
let _ = rlp.decode([byte 0xf7], enr.Record) let _ = rlp.decode([byte 0xf7], enr.Record)
var r: Record check Record.fromBytes([byte 0xf7]).isErr()
check not fromBytes(r, [byte 0xf7])
test "No RLP list": test "No RLP list":
expect ValueError: expect RlpError:
let _ = rlp.decode([byte 0x7f], enr.Record) let _ = rlp.decode([byte 0x7f], enr.Record)
var r: Record check Record.fromBytes([byte 0x7f]).isErr()
check not fromBytes(r, [byte 0x7f])
test "ENR with RLP list value": test "ENR with RLP list value":
type type
@ -121,15 +116,14 @@ suite "ENR encoding tests":
test "Base64 encode loop": test "Base64 encode loop":
const encodedBase64 = "-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnh4MHcBFZntXNFrdvJjX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8" const encodedBase64 = "-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnh4MHcBFZntXNFrdvJjX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8"
var r: Record let res = Record.fromBase64(encodedBase64)
check: check:
r.fromBase64(encodedBase64) res.isOk()
toBase64(r) == encodedBase64 toBase64(res.value) == encodedBase64
test "Invalid base64": test "Invalid base64":
var r: Record let res = Record.fromBase64("-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnhMHcBFZntXNFrdv*jX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8")
let valid = r.fromBase64("-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnhMHcBFZntXNFrdv*jX04jRzjzCBOonrkTfj499SZuOh8R33Ls8RRcy5wBgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPKY0yuDUmstAHYpMa2_oxVtw0RW_QAdpzBQA8yWM0xOIN1ZHCCdl8") check res.isErr()
check not valid
test "URI encode loop": test "URI encode loop":
let let
@ -141,20 +135,16 @@ suite "ENR encoding tests":
check res.isOk() check res.isOk()
let enr = res.value() let enr = res.value()
let uri = enr.toURI() let uri = enr.toURI()
var enr2: Record let res2 = Record.fromURI(uri)
let valid = enr2.fromURI(uri) check:
check(valid) res2.isOk()
check(enr == enr2) enr == res2.value()
test "Invalid URI: empty": test "Invalid URI: empty":
var r: Record check Record.fromURI("").isErr()
let valid = r.fromURI("")
check not valid
test "Invalid URI: no payload": test "Invalid URI: no payload":
var r: Record check Record.fromURI("enr:").isErr()
let valid = r.fromURI("enr:")
check not valid
suite "ENR init tests": suite "ENR init tests":
test "Record.init minimum fields": test "Record.init minimum fields":
@ -163,7 +153,7 @@ suite "ENR init tests":
port = Opt.none(Port) port = Opt.none(Port)
enr = Record.init( enr = Record.init(
100, keypair.seckey, Opt.none(IpAddress), port, port)[] 100, keypair.seckey, Opt.none(IpAddress), port, port)[]
typedEnr = get enr.toTypedRecord() typedEnr = TypedRecord.fromRecord(enr)
check: check:
testRlpEncodingLoop(enr) testRlpEncodingLoop(enr)
@ -186,7 +176,7 @@ suite "ENR init tests":
port = Opt.some(Port(9000)) port = Opt.some(Port(9000))
enr = Record.init( enr = Record.init(
100, keypair.seckey, Opt.some(ip), port, port)[] 100, keypair.seckey, Opt.some(ip), port, port)[]
typedEnr = get enr.toTypedRecord() typedEnr = TypedRecord.fromRecord(enr)
check: check:
typedEnr.ip.isSome() typedEnr.ip.isSome()
@ -205,7 +195,7 @@ suite "ENR init tests":
port = Opt.some(Port(9000)) port = Opt.some(Port(9000))
enr = Record.init( enr = Record.init(
100, keypair.seckey, Opt.some(ip), port, port)[] 100, keypair.seckey, Opt.some(ip), port, port)[]
typedEnr = get enr.toTypedRecord() typedEnr = TypedRecord.fromRecord(enr)
check: check:
typedEnr.ip.isNone() typedEnr.ip.isNone()
@ -400,7 +390,7 @@ suite "ENR update tests":
Opt.some(Port(9000)), Opt.some(Port(9000))) Opt.some(Port(9000)), Opt.some(Port(9000)))
check updated.isOk() check updated.isOk()
let typedEnr = r.toTypedRecord().get() let typedEnr = TypedRecord.fromRecord(r)
check: check:
typedEnr.ip.isSome() typedEnr.ip.isSome()
@ -419,7 +409,7 @@ suite "ENR update tests":
Opt.some(Port(9001)), Opt.some(Port(9001))) Opt.some(Port(9001)), Opt.some(Port(9001)))
check updated.isOk() check updated.isOk()
let typedEnr = r.toTypedRecord().get() let typedEnr = TypedRecord.fromRecord(r)
check: check:
typedEnr.ip.isSome() typedEnr.ip.isSome()

View File

@ -117,25 +117,28 @@ type
name: "node" .}: Node name: "node" .}: Node
proc parseCmdArg*(T: type enr.Record, p: string): T {.raises: [ValueError].} = proc parseCmdArg*(T: type enr.Record, p: string): T {.raises: [ValueError].} =
if not fromURI(result, p): let res = enr.Record.fromURI(p)
raise newException(ValueError, "Invalid ENR") if res.isErr:
raise newException(ValueError, "Invalid ENR:" & $res.error)
res.value
proc completeCmdArg*(T: type enr.Record, val: string): seq[string] = proc completeCmdArg*(T: type enr.Record, val: string): seq[string] =
return @[] return @[]
proc parseCmdArg*(T: type Node, p: string): T {.raises: [ValueError].} = proc parseCmdArg*(T: type Node, p: string): T {.raises: [ValueError].} =
var record: enr.Record let res = enr.Record.fromURI(p)
if not fromURI(record, p): if res.isErr:
raise newException(ValueError, "Invalid ENR") raise newException(ValueError, "Invalid ENR:" & $res.error)
let n = newNode(record) let n = newNode(res.value)
if n.isErr: if n.isErr:
raise newException(ValueError, $n.error) raise newException(ValueError, $n.error)
if n[].address.isNone(): if n.value.address.isNone():
raise newException(ValueError, "ENR without address") raise newException(ValueError, "ENR without address")
n[] n.value
proc completeCmdArg*(T: type Node, val: string): seq[string] = proc completeCmdArg*(T: type Node, val: string): seq[string] =
return @[] return @[]