Proc to func in enr code and style guide related clean-up (#555)

Should not have any functional changes. Clean-up related to
avoiding result usage (also implicit), and other style
guide items.
This commit is contained in:
Kim De Mey 2022-11-15 10:34:56 +01:00 committed by GitHub
parent 522db295f2
commit 4b22fcdce4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 101 additions and 73 deletions

View File

@ -28,7 +28,6 @@ type
Record* = object Record* = object
seqNum*: uint64 seqNum*: uint64
# signature: seq[byte]
raw*: seq[byte] # RLP encoded record raw*: seq[byte] # RLP encoded record
pairs: seq[FieldPair] # sorted list of all key/value pairs pairs: seq[FieldPair] # sorted list of all key/value pairs
@ -78,25 +77,27 @@ template toField[T](v: T): Field =
else: else:
{.error: "Unsupported field type".} {.error: "Unsupported field type".}
proc `==`(a, b: Field): bool = func `==`(a, b: Field): bool =
if a.kind == b.kind: if a.kind == b.kind:
case a.kind case a.kind
of kString: of kString:
return a.str == b.str a.str == b.str
of kNum: of kNum:
return a.num == b.num a.num == b.num
of kBytes: of kBytes:
return a.bytes == b.bytes a.bytes == b.bytes
of kList: of kList:
return a.listRaw == b.listRaw a.listRaw == b.listRaw
else: else:
return false false
proc cmp(a, b: FieldPair): int = cmp(a[0], b[0]) func cmp(a, b: FieldPair): int = cmp(a[0], b[0])
proc makeEnrRaw(seqNum: uint64, pk: PrivateKey, func makeEnrRaw(
seqNum: uint64, pk: PrivateKey,
pairs: openArray[FieldPair]): EnrResult[seq[byte]] = pairs: openArray[FieldPair]): EnrResult[seq[byte]] =
proc append(w: var RlpWriter, seqNum: uint64, func append(
w: var RlpWriter, seqNum: uint64,
pairs: openArray[FieldPair]): seq[byte] = pairs: openArray[FieldPair]): seq[byte] =
w.append(seqNum) w.append(seqNum)
for (k, v) in pairs: for (k, v) in pairs:
@ -124,7 +125,8 @@ proc makeEnrRaw(seqNum: uint64, pk: PrivateKey,
else: else:
ok(raw) ok(raw)
proc makeEnrAux(seqNum: uint64, pk: PrivateKey, func makeEnrAux(
seqNum: uint64, pk: PrivateKey,
pairs: openArray[FieldPair]): EnrResult[Record] = pairs: openArray[FieldPair]): EnrResult[Record] =
var record: Record var record: Record
record.pairs = @pairs record.pairs = @pairs
@ -144,7 +146,8 @@ proc makeEnrAux(seqNum: uint64, pk: PrivateKey,
record.raw = ? makeEnrRaw(seqNum, pk, record.pairs) record.raw = ? makeEnrRaw(seqNum, pk, record.pairs)
ok(record) ok(record)
macro initRecord*(seqNum: uint64, pk: PrivateKey, macro initRecord*(
seqNum: uint64, pk: PrivateKey,
pairs: untyped{nkTableConstr}): untyped = pairs: untyped{nkTableConstr}): untyped =
## Initialize a `Record` with given sequence number, private key and k:v ## Initialize a `Record` with given sequence number, private key and k:v
## pairs. ## pairs.
@ -160,7 +163,9 @@ macro initRecord*(seqNum: uint64, pk: PrivateKey,
template toFieldPair*(key: string, value: auto): FieldPair = template toFieldPair*(key: string, value: auto): FieldPair =
(key, toField(value)) (key, toField(value))
proc addAddress(fields: var seq[FieldPair], ip: Option[ValidIpAddress], func addAddress(
fields: var seq[FieldPair],
ip: Option[ValidIpAddress],
tcpPort, udpPort: Option[Port]) = tcpPort, udpPort: Option[Port]) =
## Add address information in new fields. Incomplete address ## Add address information in new fields. Incomplete address
## information is allowed (example: Port but not IP) as that information ## information is allowed (example: Port but not IP) as that information
@ -182,8 +187,9 @@ proc addAddress(fields: var seq[FieldPair], ip: Option[ValidIpAddress],
if udpPort.isSome(): if udpPort.isSome():
fields.add(("udp", udpPort.get().uint16.toField)) fields.add(("udp", udpPort.get().uint16.toField))
proc init*(T: type Record, seqNum: uint64, func init*(
pk: PrivateKey, T: type Record,
seqNum: uint64, pk: PrivateKey,
ip: Option[ValidIpAddress], ip: Option[ValidIpAddress],
tcpPort, udpPort: Option[Port], tcpPort, udpPort: Option[Port],
extraFields: openArray[FieldPair] = []): extraFields: openArray[FieldPair] = []):
@ -199,7 +205,7 @@ proc init*(T: type Record, seqNum: uint64,
fields.add extraFields fields.add extraFields
makeEnrAux(seqNum, pk, fields) makeEnrAux(seqNum, pk, fields)
proc getField(r: Record, name: string, field: var Field): bool = func getField(r: Record, name: string, field: var Field): bool =
# It might be more correct to do binary search, # It might be more correct to do binary search,
# as the fields are sorted, but it's unlikely to # as the fields are sorted, but it's unlikely to
# make any difference in reality. # make any difference in reality.
@ -207,14 +213,15 @@ proc getField(r: Record, name: string, field: var Field): bool =
if k == name: if k == name:
field = v field = v
return true return true
false
proc requireKind(f: Field, kind: FieldKind): EnrResult[void] = func requireKind(f: Field, kind: FieldKind): EnrResult[void] =
if f.kind != kind: if f.kind != kind:
err("Wrong field kind") err("Wrong field kind")
else: else:
ok() ok()
proc get*(r: Record, key: string, T: type): EnrResult[T] = func get*(r: Record, key: string, T: type): EnrResult[T] =
## Get the value from the provided key. ## Get the value from the provided key.
var f: Field var f: Field
if r.getField(key, f): if r.getField(key, f):
@ -250,7 +257,7 @@ proc get*(r: Record, key: string, T: type): EnrResult[T] =
else: else:
err("Key not found in ENR") err("Key not found in ENR")
proc get*(r: Record, T: type PublicKey): Option[T] = func get*(r: Record, T: type PublicKey): Option[T] =
## Get the `PublicKey` from provided `Record`. Return `none` when there is ## Get the `PublicKey` from provided `Record`. Return `none` when there is
## no `PublicKey` in the record. ## no `PublicKey` in the record.
var pubkeyField: Field var pubkeyField: Field
@ -258,16 +265,19 @@ proc get*(r: Record, T: type PublicKey): Option[T] =
let pk = PublicKey.fromRaw(pubkeyField.bytes) let pk = PublicKey.fromRaw(pubkeyField.bytes)
if pk.isOk: if pk.isOk:
return some pk[] return some pk[]
none(T)
proc find(r: Record, key: string): Option[int] = func find(r: Record, key: string): Option[int] =
## Search for key in record key:value pairs. ## Search for key in record key:value pairs.
## ##
## Returns some(index of key) if key is found in record. Else return none. ## Returns some(index of key) if key is found in record. Else return none.
for i, (k, v) in r.pairs: for i, (k, v) in r.pairs:
if k == key: if k == key:
return some(i) return some(i)
none(int)
proc update*(record: var Record, pk: PrivateKey, func update*(
record: var Record, pk: PrivateKey,
fieldPairs: openArray[FieldPair]): EnrResult[void] = fieldPairs: openArray[FieldPair]): EnrResult[void] =
## Update a `Record` k:v pairs. ## Update a `Record` k:v pairs.
## ##
@ -308,7 +318,9 @@ proc update*(record: var Record, pk: PrivateKey,
ok() ok()
proc update*(r: var Record, pk: PrivateKey, func update*(
r: var Record,
pk: PrivateKey,
ip: Option[ValidIpAddress], ip: Option[ValidIpAddress],
tcpPort, udpPort: Option[Port] = none[Port](), tcpPort, udpPort: Option[Port] = none[Port](),
extraFields: openArray[FieldPair] = []): extraFields: openArray[FieldPair] = []):
@ -329,7 +341,7 @@ proc update*(r: var Record, pk: PrivateKey,
fields.add extraFields fields.add extraFields
r.update(pk, fields) r.update(pk, fields)
proc tryGet*(r: Record, key: string, T: type): Option[T] = func tryGet*(r: Record, key: string, T: type): Option[T] =
## Get the value from the provided key. ## Get the value from the provided key.
## Return `none` if the key does not exist or if the value is invalid ## Return `none` if the key does not exist or if the value is invalid
## according to type `T`. ## according to type `T`.
@ -339,7 +351,7 @@ proc tryGet*(r: Record, key: string, T: type): Option[T] =
else: else:
none(T) none(T)
proc toTypedRecord*(r: Record): EnrResult[TypedRecord] = func toTypedRecord*(r: Record): EnrResult[TypedRecord] =
let id = r.tryGet("id", string) let id = r.tryGet("id", string)
if id.isSome: if id.isSome:
var tr: TypedRecord var tr: TypedRecord
@ -360,24 +372,29 @@ proc toTypedRecord*(r: Record): EnrResult[TypedRecord] =
else: else:
err("Record without id field") err("Record without id field")
proc contains*(r: Record, fp: (string, seq[byte])): bool = func contains*(r: Record, fp: (string, seq[byte])): bool =
# TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the # TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the
# `get` call can be improved to make this easier. # `get` call can be improved to make this easier.
let field = r.tryGet(fp[0], seq[byte]) let field = r.tryGet(fp[0], seq[byte])
if field.isSome(): if field.isSome():
if field.get() == fp[1]: if field.get() == fp[1]:
return true return true
false
proc verifySignatureV4(r: Record, sigData: openArray[byte], content: seq[byte]): func verifySignatureV4(
bool = r: Record, sigData: openArray[byte], content: seq[byte]): bool =
let publicKey = r.get(PublicKey) let publicKey = r.get(PublicKey)
if publicKey.isSome: if publicKey.isNone():
let sig = SignatureNR.fromRaw(sigData) return false
if sig.isOk:
var h = keccak256.digest(content)
return verify(sig[], SkMessage(h.data), publicKey.get)
proc verifySignature(r: Record): bool {.raises: [RlpError, Defect].} = let sig = SignatureNR.fromRaw(sigData)
if sig.isOk():
var h = keccak256.digest(content)
verify(sig[], SkMessage(h.data), publicKey.get)
else:
false
func verifySignature(r: Record): bool {.raises: [RlpError, Defect].} =
var rlp = rlpFromBytes(r.raw) var rlp = rlpFromBytes(r.raw)
let sz = rlp.listLen let sz = rlp.listLen
if not rlp.enterList: if not rlp.enterList:
@ -395,12 +412,15 @@ proc verifySignature(r: Record): bool {.raises: [RlpError, Defect].} =
if r.getField("id", id) and id.kind == kString: if r.getField("id", id) and id.kind == kString:
case id.str case id.str
of "v4": of "v4":
result = verifySignatureV4(r, sigData, content) verifySignatureV4(r, sigData, content)
else: else:
# Unknown Identity Scheme # Unknown Identity Scheme
discard false
else:
# No Identity Scheme provided
false
proc fromBytesAux(r: var Record): bool {.raises: [RlpError, Defect].} = func fromBytesAux(r: var Record): bool {.raises: [RlpError, Defect].} =
if r.raw.len > maxEnrSize: if r.raw.len > maxEnrSize:
return false return false
@ -447,39 +467,41 @@ proc fromBytesAux(r: var Record): bool {.raises: [RlpError, Defect].} =
verifySignature(r) verifySignature(r)
proc fromBytes*(r: var Record, s: openArray[byte]): bool = func fromBytes*(r: var Record, s: openArray[byte]): bool =
## Loads ENR from rlp-encoded bytes, and validates the signature. ## Loads ENR from rlp-encoded bytes, and validates the signature.
r.raw = @s r.raw = @s
try: try:
result = fromBytesAux(r) fromBytesAux(r)
except RlpError: except RlpError:
discard false
proc fromBase64*(r: var Record, s: string): bool = func fromBase64*(r: var Record, s: string): bool =
## Loads ENR from base64-encoded rlp-encoded bytes, and validates the ## Loads ENR from base64-encoded rlp-encoded bytes, and validates the
## signature. ## signature.
try: try:
r.raw = Base64Url.decode(s) r.raw = Base64Url.decode(s)
result = fromBytesAux(r) fromBytesAux(r)
except RlpError, Base64Error: except RlpError, Base64Error:
discard false
proc fromURI*(r: var Record, s: string): bool = func fromURI*(r: var Record, s: string): bool =
## Loads ENR from its text encoding: base64-encoded rlp-encoded bytes, ## Loads ENR from its text encoding: base64-encoded rlp-encoded bytes,
## prefixed with "enr:". Validates the signature. ## prefixed with "enr:". Validates the signature.
const prefix = "enr:" const prefix = "enr:"
if s.startsWith(prefix): if s.startsWith(prefix):
result = r.fromBase64(s[prefix.len .. ^1]) r.fromBase64(s[prefix.len .. ^1])
else:
false
template fromURI*(r: var Record, url: EnrUri): bool = template fromURI*(r: var Record, url: EnrUri): bool =
fromURI(r, string(url)) fromURI(r, string(url))
proc toBase64*(r: Record): string = func toBase64*(r: Record): string =
result = Base64Url.encode(r.raw) Base64Url.encode(r.raw)
proc toURI*(r: Record): string = "enr:" & r.toBase64 func toURI*(r: Record): string = "enr:" & r.toBase64
proc `$`(f: Field): string = func `$`(f: Field): string =
case f.kind case f.kind
of kNum: of kNum:
$f.num $f.num
@ -493,42 +515,48 @@ proc `$`(f: Field): string =
func `$`*(fp: FieldPair): string = func `$`*(fp: FieldPair): string =
fp[0] & ":" & $fp[1] fp[0] & ":" & $fp[1]
proc `$`*(r: Record): string = func `$`*(r: Record): string =
result = "(" var res = "("
result &= $r.seqNum res &= $r.seqNum
for (k, v) in r.pairs: for (k, v) in r.pairs:
result &= ", " res &= ", "
result &= k res &= k
result &= ": " res &= ": "
# For IP addresses we print something prettier than the default kinds # For IP addresses we print something prettier than the default kinds
# Note: Could disallow for invalid IPs in ENR also. # Note: Could disallow for invalid IPs in ENR also.
if k == "ip": if k == "ip":
let ip = r.tryGet("ip", array[4, byte]) let ip = r.tryGet("ip", array[4, byte])
if ip.isSome(): if ip.isSome():
result &= $ipv4(ip.get()) res &= $ipv4(ip.get())
else: else:
result &= "(Invalid) " & $v res &= "(Invalid) " & $v
elif k == "ip6": elif k == "ip6":
let ip = r.tryGet("ip6", array[16, byte]) let ip = r.tryGet("ip6", array[16, byte])
if ip.isSome(): if ip.isSome():
result &= $ipv6(ip.get()) res &= $ipv6(ip.get())
else: else:
result &= "(Invalid) " & $v res &= "(Invalid) " & $v
else: else:
result &= $v res &= $v
result &= ')' res &= ')'
proc `==`*(a, b: Record): bool = a.raw == b.raw res
proc read*(rlp: var Rlp, T: typedesc[Record]): func `==`*(a, b: Record): bool = a.raw == b.raw
func read*(
rlp: var Rlp, T: type Record):
T {.raises: [RlpError, ValueError, Defect].} = T {.raises: [RlpError, ValueError, Defect].} =
if not rlp.hasData() or not result.fromBytes(rlp.rawData): var res: T
if not rlp.hasData() or not res.fromBytes(rlp.rawData):
# TODO: This could also just be an invalid signature, would be cleaner to # TODO: This could also just be an invalid signature, would be cleaner to
# split of RLP deserialisation errors from this. # split of RLP deserialisation errors from this.
raise newException(ValueError, "Could not deserialize") raise newException(ValueError, "Could not deserialize")
rlp.skipElem() rlp.skipElem()
proc append*(rlpWriter: var RlpWriter, value: Record) = res
func append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw) rlpWriter.appendRawBytes(value.raw)
chronicles.formatIt(seq[FieldPair]): $it chronicles.formatIt(seq[FieldPair]): $it

View File

@ -44,7 +44,7 @@ suite "ENR":
var r: Record var r: Record
check not fromBytes(r, []) check not fromBytes(r, [])
test "Base64 deserialsation without data": test "Base64 deserialisation without data":
var r: Record var r: Record
let sigValid = r.fromURI("enr:") let sigValid = r.fromURI("enr:")
check(not sigValid) check(not sigValid)