discv5: Address review comments

kdeme 2020-05-01 22:34:26 +02:00
parent 74df90e16d
commit 887cbba563
GPG Key ID: 4E8DD21420AF43F5
4 changed files with 50 additions and 53 deletions

View File

@ -42,19 +42,19 @@ type
response*: seq[byte]
DecodeError* = enum
HandshakeError,
PacketError,
DecryptError,
UnsupportedMessage
HandshakeError = "discv5: handshake failed",
PacketError = "discv5: invalid packet",
DecryptError = "discv5: decryption failed",
UnsupportedMessage = "discv5: unsupported message"
DecodeResult*[T] = Result[T, DecodeError]
EncodeResult*[T] = Result[T, cstring]
proc mapErrTo[T, E](r: Result[T, E], v: static DecodeError):
DecodeResult[T] {.raises:[].} =
DecodeResult[T] =
r.mapErr(proc (e: E): DecodeError = v)
proc idNonceHash(nonce, ephkey: openarray[byte]): MDigest[256] {.raises:[].} =
proc idNonceHash(nonce, ephkey: openarray[byte]): MDigest[256] =
var ctx: sha256
ctx.init()
ctx.update(idNoncePrefix)
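
The hunk above gives each DecodeError member an associated string value, so stringifying an error yields a readable message instead of the bare identifier. A minimal, self-contained sketch of this Nim feature (the enum and its values below are made up for illustration):

type
  FetchError = enum
    ConnectionFailed = "fetch: connection failed"
    BadResponse = "fetch: malformed response"

# `$` on a member of an enum with string values returns the declared string,
# so errors can be logged without a separate message table.
let e = BadResponse
doAssert $e == "fetch: malformed response"
echo "request failed: ", e   # prints: request failed: fetch: malformed response
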
@ -81,8 +81,7 @@ proc deriveKeys(n1, n2: NodeID, priv: PrivateKey, pub: PublicKey,
hkdf(sha256, eph.data, idNonce, info, toOpenArray(res, 0, sizeof(secrets) - 1))
ok(secrets)
proc encryptGCM*(key, nonce, pt, authData: openarray[byte]):
seq[byte] {.raises:[].} =
proc encryptGCM*(key, nonce, pt, authData: openarray[byte]): seq[byte] =
var ectx: GCM[aes128]
ectx.init(key, nonce, authData)
result = newSeq[byte](pt.len + gcmTagSize)
@ -93,9 +92,8 @@ proc encryptGCM*(key, nonce, pt, authData: openarray[byte]):
proc encodeAuthHeader(c: Codec,
toId: NodeID,
nonce: array[gcmNonceSize, byte],
handshakeSecrets: var HandshakeSecrets,
challenge: Whoareyou):
EncodeResult[seq[byte]] =
EncodeResult[(seq[byte], HandshakeSecrets)] =
var resp = AuthResponse(version: 5)
let ln = c.localNode
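
The hunk above changes encodeAuthHeader from filling a `var HandshakeSecrets` out-parameter to returning the secrets together with the encoded header. A small sketch of that refactor pattern in isolation (the types and names here are placeholders, not the discv5 ones):

type Secrets = object
  writeKey: string

# Before: the caller pre-declares a var and trusts that it gets filled in.
proc encodeOld(input: string, secrets: var Secrets): string =
  secrets = Secrets(writeKey: "w")
  "header:" & input

# After: both values come back together, so the caller cannot read
# uninitialised or stale secrets by mistake.
proc encodeNew(input: string): (string, Secrets) =
  ("header:" & input, Secrets(writeKey: "w"))

var s: Secrets
let oldHeader = encodeOld("x", s)
let (newHeader, secrets) = encodeNew("x")
doAssert oldHeader == newHeader
doAssert s.writeKey == secrets.writeKey
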
@ -108,24 +106,24 @@ proc encodeAuthHeader(c: Codec,
ephKeys.pubkey.toRaw)
resp.signature = signature.toRaw
handshakeSecrets = ? deriveKeys(ln.id, toId, ephKeys.seckey, challenge.pubKey,
let secrets = ? deriveKeys(ln.id, toId, ephKeys.seckey, challenge.pubKey,
challenge.idNonce)
let respRlp = rlp.encode(resp)
var zeroNonce: array[gcmNonceSize, byte]
let respEnc = encryptGCM(handshakeSecrets.authRespKey, zeroNonce, respRLP, [])
let respEnc = encryptGCM(secrets.authRespKey, zeroNonce, respRLP, [])
let header = AuthHeader(auth: nonce, idNonce: challenge.idNonce,
scheme: authSchemeName, ephemeralKey: ephKeys.pubkey.toRaw,
response: respEnc)
ok(rlp.encode(header))
ok((rlp.encode(header), secrets))
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] {.raises:[].} =
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
for i in 0 .. a.high:
result[i] = a[i] xor b[i]
proc packetTag(destNode, srcNode: NodeID): PacketTag {.raises:[].} =
proc packetTag(destNode, srcNode: NodeID): PacketTag =
let
destId = destNode.toByteArrayBE()
srcId = srcNode.toByteArrayBE()
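
The generic `xor` above combines two equal-length arrays element-wise, and packetTag below it works on the big-endian byte forms of the destination and source node ids. A self-contained check of the array xor on toy 4-byte values (not real node ids):

proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
  # Same shape as the helper above: element-wise xor of two fixed-size arrays.
  for i in 0 .. a.high:
    result[i] = a[i] xor b[i]

let
  a: array[4, byte] = [0x01'u8, 0x02, 0x03, 0x04]
  b: array[4, byte] = [0xff'u8, 0x00, 0xff, 0x00]
doAssert (a xor b) == [0xfe'u8, 0x02, 0xfc, 0x04]
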
@ -135,7 +133,7 @@ proc packetTag(destNode, srcNode: NodeID): PacketTag {.raises:[].} =
proc encodePacket*(c: Codec,
toId: NodeID,
toAddr: Address,
message: seq[byte],
message: openarray[byte],
challenge: Whoareyou):
EncodeResult[(seq[byte], array[gcmNonceSize, byte])] =
var nonce: array[gcmNonceSize, byte]
@ -154,12 +152,12 @@ proc encodePacket*(c: Codec,
# yet. That's fine, we will get a whoareyou response back.
discard c.db.loadKeys(toId, toAddr, readKey, writeKey)
else:
var sec: HandshakeSecrets
headEnc = ? c.encodeAuthHeader(toId, nonce, sec, challenge)
var secrets: HandshakeSecrets
(headEnc, secrets) = ? c.encodeAuthHeader(toId, nonce, challenge)
writeKey = sec.writeKey
writeKey = secrets.writeKey
# TODO: is it safe to ignore the error here?
discard c.db.storeKeys(toId, toAddr, sec.readKey, sec.writeKey)
discard c.db.storeKeys(toId, toAddr, secrets.readKey, secrets.writeKey)
let tag = packetTag(toId, c.localNode.id)
@ -335,18 +333,17 @@ proc decodePacket*(c: var Codec,
decodeMessage(message.get())
proc newRequestId*(): Result[RequestId, cstring] {.raises:[].} =
proc newRequestId*(): Result[RequestId, cstring] =
var id: RequestId
if randomBytes(addr id, sizeof(id)) != sizeof(id):
err("Could not randomize bytes")
else:
ok(id)
proc numFields(T: typedesc): int {.raises:[].} =
proc numFields(T: typedesc): int =
for k, v in fieldPairs(default(T)): inc result
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId):
seq[byte] {.raises:[].} =
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
result = newSeqOfCap[byte](64)
result.add(messageKind(T).ord)
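
A recurring change in this commit, across all four files, is dropping the per-proc {.raises:[].} annotations. Nim can state such an expectation once for a region of a module with a push pragma, which keeps signatures readable; whether or not that is the plan for this codebase, the sketch below shows the mechanism the per-proc pragmas were duplicating (the module contents are made up):

# A pushed pragma applies to every proc declared until the matching pop,
# so individual signatures no longer need their own raises annotation.
{.push raises: [].}

proc double(x: int): int =
  # The compiler checks that nothing tracked by the effect system escapes.
  x * 2

{.pop.}

doAssert double(21) == 42
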

View File

@ -80,7 +80,7 @@ proc makeEnrAux(seqNum: uint64, pk: PrivateKey,
cmp(a[0], b[0])
proc append(w: var RlpWriter, seqNum: uint64,
pairs: openarray[(string, Field)]): seq[byte] {.raises: [].} =
pairs: openarray[(string, Field)]): seq[byte] =
w.append(seqNum)
for (k, v) in pairs:
w.append(k)
@ -139,7 +139,7 @@ proc init*(T: type Record, seqNum: uint64,
fields.add extraFields
makeEnrAux(seqNum, pk, fields)
proc getField(r: Record, name: string, field: var Field): bool {.raises: [].} =
proc getField(r: Record, name: string, field: var Field): bool =
# It might be more correct to do binary search,
# as the fields are sorted, but it's unlikely to
# make any difference in reality.
@ -152,7 +152,7 @@ proc requireKind(f: Field, kind: FieldKind) {.raises: [ValueError].} =
if f.kind != kind:
raise newException(ValueError, "Wrong field kind")
proc get*(r: Record, key: string, T: type): T {.raises: [ValueError].} =
proc get*(r: Record, key: string, T: type): T {.raises: [ValueError, Defect].} =
var f: Field
if r.getField(key, f):
when T is SomeInteger:
@ -183,14 +183,14 @@ proc get*(r: Record, key: string, T: type): T {.raises: [ValueError].} =
else:
raise newException(KeyError, "Key not found in ENR: " & key)
proc get*(r: Record, T: type PublicKey): Option[T] {.raises: [Defect].} =
proc get*(r: Record, T: type PublicKey): Option[T] =
var pubkeyField: Field
if r.getField("secp256k1", pubkeyField) and pubkeyField.kind == kBytes:
let pk = PublicKey.fromRaw(pubkeyField.bytes)
if pk.isOk:
return some pk[]
proc tryGet*(r: Record, key: string, T: type): Option[T] {.raises: [].} =
proc tryGet*(r: Record, key: string, T: type): Option[T] =
try:
return some get(r, key, T)
except CatchableError:
@ -312,12 +312,12 @@ proc fromURI*(r: var Record, s: string): bool =
template fromURI*(r: var Record, url: EnrUri): bool =
fromURI(r, string(url))
proc toBase64*(r: Record): string {.raises: [].} =
proc toBase64*(r: Record): string =
result = Base64Url.encode(r.raw)
proc toURI*(r: Record): string {.raises: [].} = "enr:" & r.toBase64
proc toURI*(r: Record): string = "enr:" & r.toBase64
proc `$`(f: Field): string {.raises: [].} =
proc `$`(f: Field): string =
case f.kind
of kNum:
$f.num
@ -326,7 +326,7 @@ proc `$`(f: Field): string {.raises: [].} =
of kString:
"\"" & f.str & "\""
proc `$`*(r: Record): string {.raises: [].} =
proc `$`*(r: Record): string =
result = "("
var first = true
for (k, v) in r.pairs:
@ -347,5 +347,5 @@ proc read*(rlp: var Rlp, T: typedesc[Record]):
raise newException(ValueError, "Could not deserialize")
rlp.skipElem()
proc append*(rlpWriter: var RlpWriter, value: Record) {.raises: [].} =
proc append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw)
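
The ENR accessors above come in two flavours: get raises (ValueError or KeyError) when a key is missing or has the wrong kind, while tryGet wraps that in a try/except and hands back an Option. A self-contained sketch of the same raising-vs-Option pairing, using a plain table instead of a Record (the names are illustrative):

import std/[tables, options]

let fields = {"id": "v4", "udp": "30303"}.toTable

proc getField(key: string): string =
  # Raising accessor: absence is treated as a programming error.
  if key notin fields:
    raise newException(KeyError, "Key not found in ENR: " & key)
  fields[key]

proc tryGetField(key: string): Option[string] =
  # Non-raising accessor: absence is an expected outcome.
  try:
    result = some(getField(key))
  except CatchableError:
    result = none(string)

doAssert getField("id") == "v4"
doAssert tryGetField("tcp").isNone
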

View File

@ -48,16 +48,16 @@ proc newNode*(r: Record): Node =
record: r)
proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.toRaw)
proc `==`*(a, b: Node): bool {.raises: [].} =
proc `==`*(a, b: Node): bool =
(a.isNil and b.isNil) or
(not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey)
proc address*(n: Node): Address {.inline, raises: [].} = n.node.address
proc address*(n: Node): Address {.inline.} = n.node.address
proc updateEndpoint*(n: Node, a: Address) {.inline, raises: [].} =
proc updateEndpoint*(n: Node, a: Address) {.inline.} =
n.node.address = a
proc `$`*(n: Node): string {.raises: [].} =
proc `$`*(n: Node): string =
if n == nil:
"Node[local]"
else:
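
The `==` above is written so that either operand may be nil without a nil dereference, which is what lets the `$` proc open with a plain n == nil check. The same pattern for a stand-in ref type (Peer is a made-up name):

type Peer = ref object
  id: int

proc `==`(a, b: Peer): bool =
  # Equal when both are nil, or when both are set and agree on id.
  (a.isNil and b.isNil) or
    (not a.isNil and not b.isNil and a.id == b.id)

let p: Peer = nil
doAssert p == nil                      # no dereference of a nil ref
doAssert Peer(id: 1) == Peer(id: 1)
doAssert not (Peer(id: 1) == nil)
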

View File

@ -21,11 +21,11 @@ const
BITS_PER_HOP = 8
ID_SIZE = 256
proc distanceTo(n: Node, id: NodeId): UInt256 {.raises: [].} =
proc distanceTo(n: Node, id: NodeId): UInt256 =
## Calculate the distance to a NodeId.
n.id xor id
proc logDist*(a, b: NodeId): uint32 {.raises: [].} =
proc logDist*(a, b: NodeId): uint32 =
## Calculate the logarithmic distance between two `NodeId`s.
##
## According to the specification, this is the log base 2 of the distance. But it
@ -44,7 +44,7 @@ proc logDist*(a, b: NodeId): uint32 {.raises: [].} =
break
return uint32(a.len * 8 - lz)
proc newKBucket(istart, iend: NodeId): KBucket {.raises: [].} =
proc newKBucket(istart, iend: NodeId): KBucket =
result.new()
result.istart = istart
result.iend = iend
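
The tail of logDist shown above returns a.len * 8 - lz, i.e. the position (counting from 1) of the highest bit in which the two ids differ, which is what the doc comment means by the log base 2 of the xor distance. The toy 8-bit version below makes the behaviour easy to see; logDist8 is a made-up miniature, the real proc works on 256-bit ids:

proc logDist8(a, b: uint8): int =
  # Miniature of logDist: number of bits needed to represent `a xor b`,
  # i.e. 8 minus the count of leading zero bits.
  var x = a xor b
  while x != 0:
    inc result
    x = x shr 1

doAssert logDist8(0, 0) == 0                          # identical ids
doAssert logDist8(0b0000_0001, 0b0000_0000) == 1
doAssert logDist8(0b1000_0000, 0b0000_0000) == 8
doAssert logDist8(0b1010_0000, 0b0010_0000) == 8      # only the highest differing bit matters
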
@ -55,13 +55,13 @@ proc midpoint(k: KBucket): NodeId =
k.istart + (k.iend - k.istart) div 2.u256
proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id
proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] {.raises: [].} =
proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] =
sortedByIt(k.nodes, it.distanceTo(id))
proc len(k: KBucket): int {.inline, raises: [].} = k.nodes.len
proc head(k: KBucket): Node {.inline, raises: [].} = k.nodes[0]
proc len(k: KBucket): int {.inline.} = k.nodes.len
proc head(k: KBucket): Node {.inline.} = k.nodes[0]
proc add(k: KBucket, n: Node): Node {.raises: [].} =
proc add(k: KBucket, n: Node): Node =
## Try to add the given node to this bucket.
## If the node is already present, it is moved to the tail of the list, and we return nil.
@ -84,7 +84,7 @@ proc add(k: KBucket, n: Node): Node {.raises: [].} =
return k.head
return nil
proc removeNode(k: KBucket, n: Node) {.raises: [].} =
proc removeNode(k: KBucket, n: Node) =
let i = k.nodes.find(n)
if i != -1: k.nodes.delete(i)
@ -100,10 +100,10 @@ proc split(k: KBucket): tuple[lower, upper: KBucket] =
let bucket = if node.id <= splitid: result.lower else: result.upper
bucket.replacementCache.add(node)
proc inRange(k: KBucket, n: Node): bool {.inline, raises: [].} =
proc inRange(k: KBucket, n: Node): bool {.inline.} =
k.istart <= n.id and n.id <= k.iend
proc contains(k: KBucket, n: Node): bool {.raises: [].} = n in k.nodes
proc contains(k: KBucket, n: Node): bool = n in k.nodes
proc binaryGetBucketForNode(buckets: openarray[KBucket],
id: NodeId): KBucket {.inline.} =
@ -140,7 +140,7 @@ proc computeSharedPrefixBits(nodes: openarray[Node]): int =
doAssert(false, "Unable to calculate number of shared prefix bits")
proc init*(r: var RoutingTable, thisNode: Node) {.inline, raises: [].} =
proc init*(r: var RoutingTable, thisNode: Node) {.inline.} =
r.thisNode = thisNode
r.buckets = @[newKBucket(0.u256, high(Uint256))]
randomize() # for later `randomNodes` selection
@ -204,7 +204,7 @@ proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
if result.len > k:
result.setLen(k)
proc idAtDistance*(id: NodeId, dist: uint32): NodeId {.raises: [].} =
proc idAtDistance*(id: NodeId, dist: uint32): NodeId =
## Calculate the "lowest" `NodeId` for given logarithmic distance.
## A logarithmic distance obviously covers a whole range of distances and thus
## potential `NodeId`s.
@ -219,10 +219,10 @@ proc neighboursAtDistance*(r: RoutingTable, distance: uint32,
# that are exactly the requested distance.
keepIf(result, proc(n: Node): bool = logDist(n.id, r.thisNode.id) == distance)
proc len*(r: RoutingTable): int {.raises: [].} =
proc len*(r: RoutingTable): int =
for b in r.buckets: result += b.len
proc moveRight[T](arr: var openarray[T], a, b: int) {.inline, raises: [].} =
proc moveRight[T](arr: var openarray[T], a, b: int) {.inline.} =
## In `arr` move elements in range [a, b] right by 1.
var t: T
shallowCopy(t, arr[b + 1])
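
The body of idAtDistance falls outside this hunk, but its doc comment explains the intent: a logarithmic distance names a whole bucket of ids, and the proc picks one representative for it. In a toy 8-bit id space, flipping only bit d - 1 of an id lands exactly at log distance d, and at the smallest xor distance inside that bucket; a quick check of that relationship (idAtDistance8 and logDist8 are made-up miniatures):

proc logDist8(a, b: uint8): int =
  # Same miniature log distance as in the earlier sketch.
  var x = a xor b
  while x != 0:
    inc result
    x = x shr 1

proc idAtDistance8(id: uint8, dist: int): uint8 =
  # Representative id at log distance `dist`: flip only bit `dist - 1`.
  id xor uint8(1 shl (dist - 1))

let base = 0b1100_0000'u8
for d in 1 .. 8:
  doAssert logDist8(base, idAtDistance8(base, d)) == d
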
@ -240,7 +240,7 @@ proc setJustSeen*(r: RoutingTable, n: Node) =
b.nodes[0] = n
b.lastUpdated = epochTime()
proc nodeToRevalidate*(r: RoutingTable): Node {.raises: [].} =
proc nodeToRevalidate*(r: RoutingTable): Node =
var buckets = r.buckets
shuffle(buckets)
# TODO: Should we prioritize less-recently-updated buckets instead?