discv5: Address review comments

This commit is contained in:
kdeme 2020-05-01 22:34:26 +02:00
parent 74df90e16d
commit 887cbba563
No known key found for this signature in database
GPG Key ID: 4E8DD21420AF43F5
4 changed files with 50 additions and 53 deletions

View File

@ -42,19 +42,19 @@ type
response*: seq[byte] response*: seq[byte]
DecodeError* = enum DecodeError* = enum
HandshakeError, HandshakeError = "discv5: handshake failed"
PacketError, PacketError = "discv5: invalid packet",
DecryptError, DecryptError = "discv5: decryption failed",
UnsupportedMessage UnsupportedMessage = "discv5: unsupported message"
DecodeResult*[T] = Result[T, DecodeError] DecodeResult*[T] = Result[T, DecodeError]
EncodeResult*[T] = Result[T, cstring] EncodeResult*[T] = Result[T, cstring]
proc mapErrTo[T, E](r: Result[T, E], v: static DecodeError): proc mapErrTo[T, E](r: Result[T, E], v: static DecodeError):
DecodeResult[T] {.raises:[].} = DecodeResult[T] =
r.mapErr(proc (e: E): DecodeError = v) r.mapErr(proc (e: E): DecodeError = v)
proc idNonceHash(nonce, ephkey: openarray[byte]): MDigest[256] {.raises:[].} = proc idNonceHash(nonce, ephkey: openarray[byte]): MDigest[256] =
var ctx: sha256 var ctx: sha256
ctx.init() ctx.init()
ctx.update(idNoncePrefix) ctx.update(idNoncePrefix)
@ -81,8 +81,7 @@ proc deriveKeys(n1, n2: NodeID, priv: PrivateKey, pub: PublicKey,
hkdf(sha256, eph.data, idNonce, info, toOpenArray(res, 0, sizeof(secrets) - 1)) hkdf(sha256, eph.data, idNonce, info, toOpenArray(res, 0, sizeof(secrets) - 1))
ok(secrets) ok(secrets)
proc encryptGCM*(key, nonce, pt, authData: openarray[byte]): proc encryptGCM*(key, nonce, pt, authData: openarray[byte]): seq[byte] =
seq[byte] {.raises:[].} =
var ectx: GCM[aes128] var ectx: GCM[aes128]
ectx.init(key, nonce, authData) ectx.init(key, nonce, authData)
result = newSeq[byte](pt.len + gcmTagSize) result = newSeq[byte](pt.len + gcmTagSize)
@ -93,9 +92,8 @@ proc encryptGCM*(key, nonce, pt, authData: openarray[byte]):
proc encodeAuthHeader(c: Codec, proc encodeAuthHeader(c: Codec,
toId: NodeID, toId: NodeID,
nonce: array[gcmNonceSize, byte], nonce: array[gcmNonceSize, byte],
handshakeSecrets: var HandshakeSecrets,
challenge: Whoareyou): challenge: Whoareyou):
EncodeResult[seq[byte]] = EncodeResult[(seq[byte], HandshakeSecrets)] =
var resp = AuthResponse(version: 5) var resp = AuthResponse(version: 5)
let ln = c.localNode let ln = c.localNode
@ -108,24 +106,24 @@ proc encodeAuthHeader(c: Codec,
ephKeys.pubkey.toRaw) ephKeys.pubkey.toRaw)
resp.signature = signature.toRaw resp.signature = signature.toRaw
handshakeSecrets = ? deriveKeys(ln.id, toId, ephKeys.seckey, challenge.pubKey, let secrets = ? deriveKeys(ln.id, toId, ephKeys.seckey, challenge.pubKey,
challenge.idNonce) challenge.idNonce)
let respRlp = rlp.encode(resp) let respRlp = rlp.encode(resp)
var zeroNonce: array[gcmNonceSize, byte] var zeroNonce: array[gcmNonceSize, byte]
let respEnc = encryptGCM(handshakeSecrets.authRespKey, zeroNonce, respRLP, []) let respEnc = encryptGCM(secrets.authRespKey, zeroNonce, respRLP, [])
let header = AuthHeader(auth: nonce, idNonce: challenge.idNonce, let header = AuthHeader(auth: nonce, idNonce: challenge.idNonce,
scheme: authSchemeName, ephemeralKey: ephKeys.pubkey.toRaw, scheme: authSchemeName, ephemeralKey: ephKeys.pubkey.toRaw,
response: respEnc) response: respEnc)
ok(rlp.encode(header)) ok((rlp.encode(header), secrets))
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] {.raises:[].} = proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
for i in 0 .. a.high: for i in 0 .. a.high:
result[i] = a[i] xor b[i] result[i] = a[i] xor b[i]
proc packetTag(destNode, srcNode: NodeID): PacketTag {.raises:[].} = proc packetTag(destNode, srcNode: NodeID): PacketTag =
let let
destId = destNode.toByteArrayBE() destId = destNode.toByteArrayBE()
srcId = srcNode.toByteArrayBE() srcId = srcNode.toByteArrayBE()
@ -135,7 +133,7 @@ proc packetTag(destNode, srcNode: NodeID): PacketTag {.raises:[].} =
proc encodePacket*(c: Codec, proc encodePacket*(c: Codec,
toId: NodeID, toId: NodeID,
toAddr: Address, toAddr: Address,
message: seq[byte], message: openarray[byte],
challenge: Whoareyou): challenge: Whoareyou):
EncodeResult[(seq[byte], array[gcmNonceSize, byte])] = EncodeResult[(seq[byte], array[gcmNonceSize, byte])] =
var nonce: array[gcmNonceSize, byte] var nonce: array[gcmNonceSize, byte]
@ -154,12 +152,12 @@ proc encodePacket*(c: Codec,
# yet. That's fine, we will be responded with whoareyou. # yet. That's fine, we will be responded with whoareyou.
discard c.db.loadKeys(toId, toAddr, readKey, writeKey) discard c.db.loadKeys(toId, toAddr, readKey, writeKey)
else: else:
var sec: HandshakeSecrets var secrets: HandshakeSecrets
headEnc = ? c.encodeAuthHeader(toId, nonce, sec, challenge) (headEnc, secrets) = ? c.encodeAuthHeader(toId, nonce, challenge)
writeKey = sec.writeKey writeKey = secrets.writeKey
# TODO: is it safe to ignore the error here? # TODO: is it safe to ignore the error here?
discard c.db.storeKeys(toId, toAddr, sec.readKey, sec.writeKey) discard c.db.storeKeys(toId, toAddr, secrets.readKey, secrets.writeKey)
let tag = packetTag(toId, c.localNode.id) let tag = packetTag(toId, c.localNode.id)
@ -335,18 +333,17 @@ proc decodePacket*(c: var Codec,
decodeMessage(message.get()) decodeMessage(message.get())
proc newRequestId*(): Result[RequestId, cstring] {.raises:[].} = proc newRequestId*(): Result[RequestId, cstring] =
var id: RequestId var id: RequestId
if randomBytes(addr id, sizeof(id)) != sizeof(id): if randomBytes(addr id, sizeof(id)) != sizeof(id):
err("Could not randomize bytes") err("Could not randomize bytes")
else: else:
ok(id) ok(id)
proc numFields(T: typedesc): int {.raises:[].} = proc numFields(T: typedesc): int =
for k, v in fieldPairs(default(T)): inc result for k, v in fieldPairs(default(T)): inc result
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
seq[byte] {.raises:[].} =
result = newSeqOfCap[byte](64) result = newSeqOfCap[byte](64)
result.add(messageKind(T).ord) result.add(messageKind(T).ord)

View File

@ -80,7 +80,7 @@ proc makeEnrAux(seqNum: uint64, pk: PrivateKey,
cmp(a[0], b[0]) cmp(a[0], b[0])
proc append(w: var RlpWriter, seqNum: uint64, proc append(w: var RlpWriter, seqNum: uint64,
pairs: openarray[(string, Field)]): seq[byte] {.raises: [].} = pairs: openarray[(string, Field)]): seq[byte] =
w.append(seqNum) w.append(seqNum)
for (k, v) in pairs: for (k, v) in pairs:
w.append(k) w.append(k)
@ -139,7 +139,7 @@ proc init*(T: type Record, seqNum: uint64,
fields.add extraFields fields.add extraFields
makeEnrAux(seqNum, pk, fields) makeEnrAux(seqNum, pk, fields)
proc getField(r: Record, name: string, field: var Field): bool {.raises: [].} = proc getField(r: Record, name: string, field: var Field): bool =
# It might be more correct to do binary search, # It might be more correct to do binary search,
# as the fields are sorted, but it's unlikely to # as the fields are sorted, but it's unlikely to
# make any difference in reality. # make any difference in reality.
@ -152,7 +152,7 @@ proc requireKind(f: Field, kind: FieldKind) {.raises: [ValueError].} =
if f.kind != kind: if f.kind != kind:
raise newException(ValueError, "Wrong field kind") raise newException(ValueError, "Wrong field kind")
proc get*(r: Record, key: string, T: type): T {.raises: [ValueError].} = proc get*(r: Record, key: string, T: type): T {.raises: [ValueError, Defect].} =
var f: Field var f: Field
if r.getField(key, f): if r.getField(key, f):
when T is SomeInteger: when T is SomeInteger:
@ -183,14 +183,14 @@ proc get*(r: Record, key: string, T: type): T {.raises: [ValueError].} =
else: else:
raise newException(KeyError, "Key not found in ENR: " & key) raise newException(KeyError, "Key not found in ENR: " & key)
proc get*(r: Record, T: type PublicKey): Option[T] {.raises: [Defect].} = proc get*(r: Record, T: type PublicKey): Option[T] =
var pubkeyField: Field var pubkeyField: Field
if r.getField("secp256k1", pubkeyField) and pubkeyField.kind == kBytes: if r.getField("secp256k1", pubkeyField) and pubkeyField.kind == kBytes:
let pk = PublicKey.fromRaw(pubkeyField.bytes) let pk = PublicKey.fromRaw(pubkeyField.bytes)
if pk.isOk: if pk.isOk:
return some pk[] return some pk[]
proc tryGet*(r: Record, key: string, T: type): Option[T] {.raises: [].} = proc tryGet*(r: Record, key: string, T: type): Option[T] =
try: try:
return some get(r, key, T) return some get(r, key, T)
except CatchableError: except CatchableError:
@ -312,12 +312,12 @@ proc fromURI*(r: var Record, s: string): bool =
template fromURI*(r: var Record, url: EnrUri): bool = template fromURI*(r: var Record, url: EnrUri): bool =
fromURI(r, string(url)) fromURI(r, string(url))
proc toBase64*(r: Record): string {.raises: [].} = proc toBase64*(r: Record): string =
result = Base64Url.encode(r.raw) result = Base64Url.encode(r.raw)
proc toURI*(r: Record): string {.raises: [].} = "enr:" & r.toBase64 proc toURI*(r: Record): string = "enr:" & r.toBase64
proc `$`(f: Field): string {.raises: [].} = proc `$`(f: Field): string =
case f.kind case f.kind
of kNum: of kNum:
$f.num $f.num
@ -326,7 +326,7 @@ proc `$`(f: Field): string {.raises: [].} =
of kString: of kString:
"\"" & f.str & "\"" "\"" & f.str & "\""
proc `$`*(r: Record): string {.raises: [].} = proc `$`*(r: Record): string =
result = "(" result = "("
var first = true var first = true
for (k, v) in r.pairs: for (k, v) in r.pairs:
@ -347,5 +347,5 @@ proc read*(rlp: var Rlp, T: typedesc[Record]):
raise newException(ValueError, "Could not deserialize") raise newException(ValueError, "Could not deserialize")
rlp.skipElem() rlp.skipElem()
proc append*(rlpWriter: var RlpWriter, value: Record) {.raises: [].} = proc append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw) rlpWriter.appendRawBytes(value.raw)

View File

@ -48,16 +48,16 @@ proc newNode*(r: Record): Node =
record: r) record: r)
proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.toRaw) proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.toRaw)
proc `==`*(a, b: Node): bool {.raises: [].} = proc `==`*(a, b: Node): bool =
(a.isNil and b.isNil) or (a.isNil and b.isNil) or
(not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey) (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey)
proc address*(n: Node): Address {.inline, raises: [].} = n.node.address proc address*(n: Node): Address {.inline.} = n.node.address
proc updateEndpoint*(n: Node, a: Address) {.inline, raises: [].} = proc updateEndpoint*(n: Node, a: Address) {.inline.} =
n.node.address = a n.node.address = a
proc `$`*(n: Node): string {.raises: [].} = proc `$`*(n: Node): string =
if n == nil: if n == nil:
"Node[local]" "Node[local]"
else: else:

View File

@ -21,11 +21,11 @@ const
BITS_PER_HOP = 8 BITS_PER_HOP = 8
ID_SIZE = 256 ID_SIZE = 256
proc distanceTo(n: Node, id: NodeId): UInt256 {.raises: [].} = proc distanceTo(n: Node, id: NodeId): UInt256 =
## Calculate the distance to a NodeId. ## Calculate the distance to a NodeId.
n.id xor id n.id xor id
proc logDist*(a, b: NodeId): uint32 {.raises: [].} = proc logDist*(a, b: NodeId): uint32 =
## Calculate the logarithmic distance between two `NodeId`s. ## Calculate the logarithmic distance between two `NodeId`s.
## ##
## According the specification, this is the log base 2 of the distance. But it ## According the specification, this is the log base 2 of the distance. But it
@ -44,7 +44,7 @@ proc logDist*(a, b: NodeId): uint32 {.raises: [].} =
break break
return uint32(a.len * 8 - lz) return uint32(a.len * 8 - lz)
proc newKBucket(istart, iend: NodeId): KBucket {.raises: [].} = proc newKBucket(istart, iend: NodeId): KBucket =
result.new() result.new()
result.istart = istart result.istart = istart
result.iend = iend result.iend = iend
@ -55,13 +55,13 @@ proc midpoint(k: KBucket): NodeId =
k.istart + (k.iend - k.istart) div 2.u256 k.istart + (k.iend - k.istart) div 2.u256
proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id
proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] {.raises: [].} = proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] =
sortedByIt(k.nodes, it.distanceTo(id)) sortedByIt(k.nodes, it.distanceTo(id))
proc len(k: KBucket): int {.inline, raises: [].} = k.nodes.len proc len(k: KBucket): int {.inline.} = k.nodes.len
proc head(k: KBucket): Node {.inline, raises: [].} = k.nodes[0] proc head(k: KBucket): Node {.inline.} = k.nodes[0]
proc add(k: KBucket, n: Node): Node {.raises: [].} = proc add(k: KBucket, n: Node): Node =
## Try to add the given node to this bucket. ## Try to add the given node to this bucket.
## If the node is already present, it is moved to the tail of the list, and we return nil. ## If the node is already present, it is moved to the tail of the list, and we return nil.
@ -84,7 +84,7 @@ proc add(k: KBucket, n: Node): Node {.raises: [].} =
return k.head return k.head
return nil return nil
proc removeNode(k: KBucket, n: Node) {.raises: [].} = proc removeNode(k: KBucket, n: Node) =
let i = k.nodes.find(n) let i = k.nodes.find(n)
if i != -1: k.nodes.delete(i) if i != -1: k.nodes.delete(i)
@ -100,10 +100,10 @@ proc split(k: KBucket): tuple[lower, upper: KBucket] =
let bucket = if node.id <= splitid: result.lower else: result.upper let bucket = if node.id <= splitid: result.lower else: result.upper
bucket.replacementCache.add(node) bucket.replacementCache.add(node)
proc inRange(k: KBucket, n: Node): bool {.inline, raises: [].} = proc inRange(k: KBucket, n: Node): bool {.inline.} =
k.istart <= n.id and n.id <= k.iend k.istart <= n.id and n.id <= k.iend
proc contains(k: KBucket, n: Node): bool {.raises: [].} = n in k.nodes proc contains(k: KBucket, n: Node): bool = n in k.nodes
proc binaryGetBucketForNode(buckets: openarray[KBucket], proc binaryGetBucketForNode(buckets: openarray[KBucket],
id: NodeId): KBucket {.inline.} = id: NodeId): KBucket {.inline.} =
@ -140,7 +140,7 @@ proc computeSharedPrefixBits(nodes: openarray[Node]): int =
doAssert(false, "Unable to calculate number of shared prefix bits") doAssert(false, "Unable to calculate number of shared prefix bits")
proc init*(r: var RoutingTable, thisNode: Node) {.inline, raises: [].} = proc init*(r: var RoutingTable, thisNode: Node) {.inline.} =
r.thisNode = thisNode r.thisNode = thisNode
r.buckets = @[newKBucket(0.u256, high(Uint256))] r.buckets = @[newKBucket(0.u256, high(Uint256))]
randomize() # for later `randomNodes` selection randomize() # for later `randomNodes` selection
@ -204,7 +204,7 @@ proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
if result.len > k: if result.len > k:
result.setLen(k) result.setLen(k)
proc idAtDistance*(id: NodeId, dist: uint32): NodeId {.raises: [].} = proc idAtDistance*(id: NodeId, dist: uint32): NodeId =
## Calculate the "lowest" `NodeId` for given logarithmic distance. ## Calculate the "lowest" `NodeId` for given logarithmic distance.
## A logarithmic distance obviously covers a whole range of distances and thus ## A logarithmic distance obviously covers a whole range of distances and thus
## potential `NodeId`s. ## potential `NodeId`s.
@ -219,10 +219,10 @@ proc neighboursAtDistance*(r: RoutingTable, distance: uint32,
# that are exactly the requested distance. # that are exactly the requested distance.
keepIf(result, proc(n: Node): bool = logDist(n.id, r.thisNode.id) == distance) keepIf(result, proc(n: Node): bool = logDist(n.id, r.thisNode.id) == distance)
proc len*(r: RoutingTable): int {.raises: [].} = proc len*(r: RoutingTable): int =
for b in r.buckets: result += b.len for b in r.buckets: result += b.len
proc moveRight[T](arr: var openarray[T], a, b: int) {.inline, raises: [].} = proc moveRight[T](arr: var openarray[T], a, b: int) {.inline.} =
## In `arr` move elements in range [a, b] right by 1. ## In `arr` move elements in range [a, b] right by 1.
var t: T var t: T
shallowCopy(t, arr[b + 1]) shallowCopy(t, arr[b + 1])
@ -240,7 +240,7 @@ proc setJustSeen*(r: RoutingTable, n: Node) =
b.nodes[0] = n b.nodes[0] = n
b.lastUpdated = epochTime() b.lastUpdated = epochTime()
proc nodeToRevalidate*(r: RoutingTable): Node {.raises: [].} = proc nodeToRevalidate*(r: RoutingTable): Node =
var buckets = r.buckets var buckets = r.buckets
shuffle(buckets) shuffle(buckets)
# TODO: Should we prioritize less-recently-updated buckets instead? # TODO: Should we prioritize less-recently-updated buckets instead?