Discv5 eh2 (#240)

* Discv5: More error handling improvements

- More results usage and raises pragma annotations
- Remove ENode related code and adjust Node object
- Misc.

* Add sendMessage and catch RlpError when decoding WhoAreYou

* Make the receive proc exception free

Except for `Exception` hah...

* Address review comments

* And another bunch of results and raises annotations

* Send Nodes Message also on 0 nodes and remove usage of broken require
This commit is contained in:
Kim De Mey 2020-05-28 10:19:36 +02:00 committed by GitHub
parent ff546d27c3
commit a110f091af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 381 additions and 267 deletions

View File

@ -1,6 +1,8 @@
import import
std/net, std/net, stint, stew/endians2,
eth/trie/db, types, ../enode eth/trie/db, types, node
{.push raises: [Defect].}
type type
DiscoveryDB* = ref object of Database DiscoveryDB* = ref object of Database
@ -19,16 +21,19 @@ const keySize = 1 + # unique triedb prefix (kNodeToKeys)
proc makeKey(id: NodeId, address: Address): array[keySize, byte] = proc makeKey(id: NodeId, address: Address): array[keySize, byte] =
result[0] = byte(kNodeToKeys) result[0] = byte(kNodeToKeys)
copyMem(addr result[1], unsafeAddr id, sizeof(id)) var pos = 1
result[pos ..< pos+sizeof(id)] = toBytes(id)
pos.inc(sizeof(id))
case address.ip.family case address.ip.family
of IpAddressFamily.IpV4: of IpAddressFamily.IpV4:
copyMem(addr result[sizeof(id) + 1], unsafeAddr address.ip.address_v4, sizeof(address.ip.address_v4)) result[pos ..< pos+sizeof(address.ip.address_v4)] = address.ip.address_v4
of IpAddressFamily.IpV6: of IpAddressFamily.IpV6:
copyMem(addr result[sizeof(id) + 1], unsafeAddr address.ip.address_v6, sizeof(address.ip.address_v6)) result[pos..< pos+sizeof(address.ip.address_v6)] = address.ip.address_v6
copyMem(addr result[sizeof(id) + 1 + sizeof(address.ip.address_v6)], unsafeAddr address.udpPort, sizeof(address.udpPort)) pos.inc(sizeof(address.ip.address_v6))
result[pos ..< pos+sizeof(address.port)] = toBytes(address.port.uint16)
method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: AesKey): method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: AesKey):
bool {.raises: [Defect].} = bool =
try: try:
var value: array[sizeof(r) + sizeof(w), byte] var value: array[sizeof(r) + sizeof(w), byte]
value[0 .. 15] = r value[0 .. 15] = r
@ -38,8 +43,8 @@ method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: AesKey):
except CatchableError: except CatchableError:
return false return false
method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: var AesKey): method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address,
bool {.raises: [Defect].} = r, w: var AesKey): bool =
try: try:
let res = db.backend.get(makeKey(id, address)) let res = db.backend.get(makeKey(id, address))
if res.len != sizeof(r) + sizeof(w): if res.len != sizeof(r) + sizeof(w):
@ -50,8 +55,7 @@ method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: var AesKey
except CatchableError: except CatchableError:
return false return false
method deleteKeys*(db: DiscoveryDB, id: NodeId, address: Address): method deleteKeys*(db: DiscoveryDB, id: NodeId, address: Address): bool =
bool {.raises: [Defect].} =
try: try:
db.backend.del(makeKey(id, address)) db.backend.del(makeKey(id, address))
return true return true

View File

@ -1,6 +1,6 @@
import import
std/[tables, options], nimcrypto, stint, chronicles, stew/results, std/[tables, options], nimcrypto, stint, chronicles, stew/results,
types, node, enr, hkdf, ../enode, eth/[rlp, keys] types, node, enr, hkdf, eth/[rlp, keys]
export keys export keys
@ -43,8 +43,8 @@ type
DecodeError* = enum DecodeError* = enum
HandshakeError = "discv5: handshake failed" HandshakeError = "discv5: handshake failed"
PacketError = "discv5: invalid packet", PacketError = "discv5: invalid packet"
DecryptError = "discv5: decryption failed", DecryptError = "discv5: decryption failed"
UnsupportedMessage = "discv5: unsupported message" UnsupportedMessage = "discv5: unsupported message"
DecodeResult*[T] = Result[T, DecodeError] DecodeResult*[T] = Result[T, DecodeError]
@ -253,14 +253,9 @@ proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader,
# 2. Should verify ENR and check for correct id in case an ENR is included # 2. Should verify ENR and check for correct id in case an ENR is included
# 3. Should verify id nonce signature # 3. Should verify id nonce signature
# More TODO: # Node returned might not have an address or not a valid address
# This will also not work if ENR does not contain an IP address or if the newNode = ? newNode(authResp.record).mapErrTo(HandshakeError)
# IP address is out of date and doesn't match current UDP end point
try:
newNode = newNode(authResp.record)
ok() ok()
except KeyError, ValueError:
err(HandshakeError)
proc decodePacket*(c: var Codec, proc decodePacket*(c: var Codec,
fromId: NodeID, fromId: NodeID,
@ -299,11 +294,6 @@ proc decodePacket*(c: var Codec,
c.handshakes.del(key) c.handshakes.del(key)
# For an incoming handshake, we are not sure the address in the ENR is there
# and if it is the real external IP, so we use the one we know from the
# UDP packet.
updateEndpoint(newNode, fromAddr)
# Swap keys to match remote # Swap keys to match remote
swap(sec.readKey, sec.writeKey) swap(sec.readKey, sec.writeKey)
# TODO: is it safe to ignore the error here? # TODO: is it safe to ignore the error here?

View File

@ -4,7 +4,7 @@
import import
net, strutils, macros, algorithm, options, net, strutils, macros, algorithm, options,
nimcrypto, stew/base64, nimcrypto, stew/base64,
eth/[rlp, keys], ../enode eth/[rlp, keys]
export options export options
@ -193,10 +193,10 @@ proc get*(r: Record, T: type PublicKey): Option[T] =
proc tryGet*(r: Record, key: string, T: type): Option[T] = proc tryGet*(r: Record, key: string, T: type): Option[T] =
try: try:
return some get(r, key, T) return some get(r, key, T)
except CatchableError: except ValueError:
discard discard
proc toTypedRecord*(r: Record): Option[TypedRecord] = proc toTypedRecord*(r: Record): EnrResult[TypedRecord] =
let id = r.tryGet("id", string) let id = r.tryGet("id", string)
if id.isSome: if id.isSome:
var tr: TypedRecord var tr: TypedRecord
@ -213,7 +213,9 @@ proc toTypedRecord*(r: Record): Option[TypedRecord] =
readField udp readField udp
readField udp6 readField udp6
return some(tr) ok(tr)
else:
err("Record without id field")
proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]): proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]):
bool = bool =
@ -290,7 +292,7 @@ proc fromBytes*(r: var Record, s: openarray[byte]): bool =
r.raw = @s r.raw = @s
try: try:
result = fromBytesAux(r) result = fromBytesAux(r)
except CatchableError: except RlpError:
discard discard
proc fromBase64*(r: var Record, s: string): bool = proc fromBase64*(r: var Record, s: string): bool =
@ -299,7 +301,7 @@ proc fromBase64*(r: var Record, s: string): bool =
try: try:
r.raw = Base64Url.decode(s) r.raw = Base64Url.decode(s)
result = fromBytesAux(r) result = fromBytesAux(r)
except CatchableError: except RlpError, Base64Error:
discard discard
proc fromURI*(r: var Record, s: string): bool = proc fromURI*(r: var Record, s: string): bool =
@ -344,6 +346,8 @@ proc `==`*(a, b: Record): bool = a.raw == b.raw
proc read*(rlp: var Rlp, T: typedesc[Record]): proc read*(rlp: var Rlp, T: typedesc[Record]):
T {.inline, raises:[RlpError, ValueError, Defect].} = T {.inline, raises:[RlpError, ValueError, Defect].} =
if not result.fromBytes(rlp.rawData): if not result.fromBytes(rlp.rawData):
# TODO: This could also just be an invalid signature, would be cleaner to
# split of RLP deserialisation errors from this.
raise newException(ValueError, "Could not deserialize") raise newException(ValueError, "Could not deserialize")
rlp.skipElem() rlp.skipElem()

View File

@ -1,64 +1,59 @@
import import
std/[net, hashes], nimcrypto, stint, chronicles, std/[net, hashes], nimcrypto, stint, chronos,
types, enr, eth/keys, ../enode eth/keys, enr
{.push raises: [Defect].} {.push raises: [Defect].}
type type
NodeId* = UInt256
Address* = object
ip*: IpAddress
port*: Port
Node* = ref object Node* = ref object
node*: ENode
id*: NodeId id*: NodeId
pubkey*: PublicKey
address*: Option[Address]
record*: Record record*: Record
proc toNodeId*(pk: PublicKey): NodeId = proc toNodeId*(pk: PublicKey): NodeId =
readUintBE[256](keccak256.digest(pk.toRaw()).data) readUintBE[256](keccak256.digest(pk.toRaw()).data)
# TODO: Lets not allow to create a node where enode info is not in sync with the proc newNode*(r: Record): Result[Node, cstring] =
# record
proc newNode*(enode: ENode, r: Record): Node =
Node(node: enode,
id: enode.pubkey.toNodeId(),
record: r)
proc newNode*(r: Record): Node =
# TODO: Handle IPv6 # TODO: Handle IPv6
var a: Address
try:
let
ipBytes = r.get("ip", array[4, byte])
udpPort = r.get("udp", uint16)
a = Address(ip: IpAddress(family: IpAddressFamily.IPv4,
address_v4: ipBytes),
udpPort: Port udpPort)
except KeyError, ValueError:
# TODO: This will result in a 0.0.0.0 address. Might introduce more bugs.
# Maybe we shouldn't allow the creation of Node from Record without IP.
# Will need some refactor though.
discard
let pk = r.get(PublicKey) let pk = r.get(PublicKey)
# This check is redundant as the deserialisation of `Record` will already fail
# at `verifySignature` if there is no public key
if pk.isNone(): if pk.isNone():
warn "Could not recover public key from ENR" return err("Could not recover public key from ENR")
return
let enode = ENode(pubkey: pk.get(), address: a) let tr = ? r.toTypedRecord()
result = Node(node: enode, if tr.ip.isSome() and tr.udp.isSome():
id: enode.pubkey.toNodeId(), let
record: r) ip = IpAddress(family: IpAddressFamily.IPv4, address_v4: tr.ip.get())
a = Address(ip: ip, port: Port(tr.udp.get()))
proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.toRaw) ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r,
address: some(a)))
else:
ok(Node(id: pk.get().toNodeId(), pubkey: pk.get(), record: r,
address: none(Address)))
proc hash*(n: Node): hashes.Hash = hash(n.pubkey.toRaw)
proc `==`*(a, b: Node): bool = proc `==`*(a, b: Node): bool =
(a.isNil and b.isNil) or (a.isNil and b.isNil) or
(not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey) (not a.isNil and not b.isNil and a.pubkey == b.pubkey)
proc address*(n: Node): Address {.inline.} = n.node.address proc `$`*(a: Address): string =
result.add($a.ip)
proc updateEndpoint*(n: Node, a: Address) {.inline.} = result.add(":" & $a.port)
n.node.address = a
proc `$`*(n: Node): string = proc `$`*(n: Node): string =
if n == nil: if n == nil:
"Node[local]" "Node[uninitialized]"
elif n.address.isNone():
"Node[unaddressable]"
else: else:
"Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]" "Node[" & $n.address.get().ip & ":" & $n.address.get().port & "]"

View File

@ -76,12 +76,14 @@ import
std/[tables, sets, options, math, random], std/[tables, sets, options, math, random],
json_serialization/std/net, json_serialization/std/net,
stew/[byteutils, endians2], chronicles, chronos, stint, stew/[byteutils, endians2], chronicles, chronos, stint,
eth/[rlp, keys], ../enode, types, encoding, node, routing_table, enr eth/[rlp, keys], types, encoding, node, routing_table, enr
import nimcrypto except toHex import nimcrypto except toHex
export options export options
{.push raises: [Defect].}
logScope: logScope:
topics = "discv5" topics = "discv5"
@ -120,22 +122,24 @@ type
node: Node node: Node
message: seq[byte] message: seq[byte]
RandomSourceDepleted* = object of CatchableError DiscResult*[T] = Result[T, cstring]
proc addNode*(d: Protocol, node: Node) = proc addNode*(d: Protocol, node: Node): bool =
if node.address.isSome():
# Only add nodes with an address to the routing table
discard d.routingTable.addNode(node) discard d.routingTable.addNode(node)
return true
template addNode*(d: Protocol, enode: ENode) = proc addNode*(d: Protocol, r: Record): bool =
addNode d, newNode(enode) let node = newNode(r)
if node.isOk():
return d.addNode(node[])
template addNode*(d: Protocol, r: Record) = proc addNode*(d: Protocol, enr: EnrUri): bool =
addNode d, newNode(r)
proc addNode*(d: Protocol, enr: EnrUri) =
var r: Record var r: Record
let res = r.fromUri(enr) let res = r.fromUri(enr)
doAssert(res) if res:
d.addNode newNode(r) return d.addNode(r)
proc getNode*(d: Protocol, id: NodeId): Option[Node] = proc getNode*(d: Protocol, id: NodeId): Option[Node] =
d.routingTable.getNode(id) d.routingTable.getNode(id)
@ -152,15 +156,31 @@ func privKey*(d: Protocol): lent PrivateKey =
d.privateKey d.privateKey
proc send(d: Protocol, a: Address, data: seq[byte]) = proc send(d: Protocol, a: Address, data: seq[byte]) =
# debug "Sending bytes", amount = data.len, to = a let ta = initTAddress(a.ip, a.port)
let ta = initTAddress(a.ip, a.udpPort) try:
let f = d.transp.sendTo(ta, data) let f = d.transp.sendTo(ta, data)
f.callback = proc(data: pointer) {.gcsafe.} = f.callback = proc(data: pointer) {.gcsafe.} =
if f.failed: if f.failed:
# Could be `TransportUseClosedError` in case the transport is already
# closed, or could be `TransportOsError` in case of a socket error.
# In the latter case this would probably mostly occur if the network
# interface underneath gets disconnected or similar.
# TODO: Should this kind of error be propagated upwards? Probably, but
# it should not stop the process as that would reset the discovery
# progress in case there is even a small window of no connection.
# One case that needs this error available upwards is when revalidating
# nodes. Else the revalidation might end up clearing the routing tabl
# because of ping failures due to own network connection failure.
debug "Discovery send failed", msg = f.readError.msg debug "Discovery send failed", msg = f.readError.msg
except Exception as e:
# TODO: General exception still being raised from Chronos.
if e of Defect:
raise (ref Defect)(e)
else: doAssert(false)
proc send(d: Protocol, n: Node, data: seq[byte]) = proc send(d: Protocol, n: Node, data: seq[byte]) =
d.send(n.node.address, data) doAssert(n.address.isSome())
d.send(n.address.get(), data)
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] = proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
for i in 0 .. a.high: for i in 0 .. a.high:
@ -177,16 +197,19 @@ proc isWhoAreYou(d: Protocol, packet: openArray[byte]): bool =
if packet.len > d.whoareyouMagic.len: if packet.len > d.whoareyouMagic.len:
result = d.whoareyouMagic == packet.toOpenArray(0, magicSize - 1) result = d.whoareyouMagic == packet.toOpenArray(0, magicSize - 1)
proc decodeWhoAreYou(d: Protocol, packet: openArray[byte]): Whoareyou = proc decodeWhoAreYou(d: Protocol, packet: openArray[byte]):
Whoareyou {.raises: [RlpError].} =
result = Whoareyou() result = Whoareyou()
result[] = rlp.decode(packet.toOpenArray(magicSize, packet.high), WhoareyouObj) result[] = rlp.decode(packet.toOpenArray(magicSize, packet.high), WhoareyouObj)
proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: AuthTag) = proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId,
authTag: AuthTag): DiscResult[void] {.raises: [Exception, Defect].} =
trace "sending who are you", to = $toNode, toAddress = $address trace "sending who are you", to = $toNode, toAddress = $address
let challenge = Whoareyou(authTag: authTag, recordSeq: 0) let challenge = Whoareyou(authTag: authTag, recordSeq: 0)
if randomBytes(challenge.idNonce) != challenge.idNonce.len: if randomBytes(challenge.idNonce) != challenge.idNonce.len:
raise newException(RandomSourceDepleted, "Could not randomize bytes") return err("Could not randomize bytes")
# If there is already a handshake going on for this nodeid then we drop this # If there is already a handshake going on for this nodeid then we drop this
# new one. Handshake will get cleaned up after `handshakeTimeout`. # new one. Handshake will get cleaned up after `handshakeTimeout`.
# If instead overwriting the handshake would be allowed, the handshake timeout # If instead overwriting the handshake would be allowed, the handshake timeout
@ -195,9 +218,9 @@ proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: AuthT
# a loop. # a loop.
# Use toNode + address to make it more difficult for an attacker to occupy # Use toNode + address to make it more difficult for an attacker to occupy
# the handshake of another node. # the handshake of another node.
let key = HandShakeKey(nodeId: toNode, address: $address) let key = HandShakeKey(nodeId: toNode, address: $address)
if not d.codec.handshakes.hasKeyOrPut(key, challenge): if not d.codec.handshakes.hasKeyOrPut(key, challenge):
# TODO: raises: [Exception]
sleepAsync(handshakeTimeout).addCallback() do(data: pointer): sleepAsync(handshakeTimeout).addCallback() do(data: pointer):
# TODO: should we still provide cancellation in case handshake completes # TODO: should we still provide cancellation in case handshake completes
# correctly? # correctly?
@ -206,43 +229,61 @@ proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: AuthT
var data = @(whoareyouMagic(toNode)) var data = @(whoareyouMagic(toNode))
data.add(rlp.encode(challenge[])) data.add(rlp.encode(challenge[]))
d.send(address, data) d.send(address, data)
ok()
else:
err("NodeId already has ongoing handshake")
proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, reqId: RequestId, proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, reqId: RequestId,
nodes: openarray[Node]) = nodes: openarray[Node]): DiscResult[void] =
proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address,
message: NodesMessage, reqId: RequestId) {.nimcall.} = message: NodesMessage, reqId: RequestId): DiscResult[void] {.nimcall.} =
let (data, _) = d.codec.encodePacket(toId, toAddr, let (data, _) = ? d.codec.encodePacket(toId, toAddr,
encodeMessage(message, reqId), challenge = nil).tryGet() encodeMessage(message, reqId), challenge = nil)
d.send(toAddr, data) d.send(toAddr, data)
ok()
if nodes.len == 0:
# In case of 0 nodes, a reply is still needed
return d.sendNodes(toId, toAddr, NodesMessage(total: 1, enrs: @[]), reqId)
var message: NodesMessage var message: NodesMessage
# TODO: Do the total calculation based on the max UDP packet size we want to
# send and the ENR size of all (max 16) nodes.
# Which UDP packet size to take? 1280? 576?
message.total = ceil(nodes.len / maxNodesPerMessage).uint32 message.total = ceil(nodes.len / maxNodesPerMessage).uint32
for i in 0 ..< nodes.len: for i in 0 ..< nodes.len:
message.enrs.add(nodes[i].record) message.enrs.add(nodes[i].record)
if message.enrs.len == 3: # TODO: Uh, what is this? if message.enrs.len == maxNodesPerMessage:
d.sendNodes(toId, toAddr, message, reqId) let res = d.sendNodes(toId, toAddr, message, reqId)
if res.isErr: # TODO: is there something nicer for this?
return res
message.enrs.setLen(0) message.enrs.setLen(0)
if message.enrs.len != 0: if message.enrs.len != 0:
d.sendNodes(toId, toAddr, message, reqId) let res = d.sendNodes(toId, toAddr, message, reqId)
if res.isErr: # TODO: is there something nicer for this?
return res
ok()
proc handlePing(d: Protocol, fromId: NodeId, fromAddr: Address, proc handlePing(d: Protocol, fromId: NodeId, fromAddr: Address,
ping: PingMessage, reqId: RequestId) = ping: PingMessage, reqId: RequestId): DiscResult[void] =
let a = fromAddr let a = fromAddr
var pong: PongMessage var pong: PongMessage
pong.enrSeq = ping.enrSeq pong.enrSeq = ping.enrSeq
pong.ip = case a.ip.family pong.ip = case a.ip.family
of IpAddressFamily.IPv4: @(a.ip.address_v4) of IpAddressFamily.IPv4: @(a.ip.address_v4)
of IpAddressFamily.IPv6: @(a.ip.address_v6) of IpAddressFamily.IPv6: @(a.ip.address_v6)
pong.port = a.udpPort.uint16 pong.port = a.port.uint16
let (data, _) = ? d.codec.encodePacket(fromId, fromAddr,
encodeMessage(pong, reqId), challenge = nil)
let (data, _) = d.codec.encodePacket(fromId, fromAddr,
encodeMessage(pong, reqId), challenge = nil).tryGet()
d.send(fromAddr, data) d.send(fromAddr, data)
ok()
proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address, proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address,
fn: FindNodeMessage, reqId: RequestId) = fn: FindNodeMessage, reqId: RequestId): DiscResult[void] =
if fn.distance == 0: if fn.distance == 0:
d.sendNodes(fromId, fromAddr, reqId, [d.localNode]) d.sendNodes(fromId, fromAddr, reqId, [d.localNode])
else: else:
@ -253,12 +294,9 @@ proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address,
proc receive*(d: Protocol, a: Address, packet: openArray[byte]) {.gcsafe, proc receive*(d: Protocol, a: Address, packet: openArray[byte]) {.gcsafe,
raises: [ raises: [
Defect, Defect,
# TODO This is now coming from Chronos's callSoon # This just comes now from a future.complete() and `sendWhoareyou` which
Exception, # has it because of `sleepAsync` with `addCallback`
# TODO All of these should probably be handled here Exception
RlpError,
IOError,
TransportAddressError,
].} = ].} =
if packet.len < tagSize: # or magicSize, can be either if packet.len < tagSize: # or magicSize, can be either
return # Invalid packet return # Invalid packet
@ -267,18 +305,29 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) {.gcsafe,
if d.isWhoAreYou(packet): if d.isWhoAreYou(packet):
trace "Received whoareyou", localNode = $d.localNode, address = a trace "Received whoareyou", localNode = $d.localNode, address = a
let whoareyou = d.decodeWhoAreYou(packet) var whoareyou: WhoAreYou
try:
whoareyou = d.decodeWhoAreYou(packet)
except RlpError:
debug "Invalid WhoAreYou packet, decoding failed"
return
var pr: PendingRequest var pr: PendingRequest
if d.pendingRequests.take(whoareyou.authTag, pr): if d.pendingRequests.take(whoareyou.authTag, pr):
let toNode = pr.node let toNode = pr.node
whoareyou.pubKey = toNode.node.pubkey # TODO: Yeah, rather ugly this. whoareyou.pubKey = toNode.pubkey # TODO: Yeah, rather ugly this.
try: doAssert(toNode.address.isSome())
let (data, _) = d.codec.encodePacket(toNode.id, toNode.address, let encoded = d.codec.encodePacket(toNode.id, toNode.address.get(),
pr.message, challenge = whoareyou).tryGet() pr.message, challenge = whoareyou)
# TODO: Perhaps just expect here? Or raise Defect in `encodePacket`?
# if this occurs there is an issue with the system anyhow?
if encoded.isErr:
warn "Not enough randomness to encode packet"
return
let (data, _) = encoded[]
d.send(toNode, data) d.send(toNode, data)
except RandomSourceDepleted: else:
debug "Failed to respond to a who-you-are packet " & debug "Timed out or unrequested WhoAreYou packet"
"due to randomness source depletion."
else: else:
var tag: array[tagSize, byte] var tag: array[tagSize, byte]
@ -293,56 +342,87 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) {.gcsafe,
let message = decoded[] let message = decoded[]
if not node.isNil: if not node.isNil:
# Not filling table with nodes without correct IP in the ENR # Not filling table with nodes without correct IP in the ENR
if a.ip == node.address.ip: # TODO: Should we care about this???
if node.address.isSome() and a == node.address.get():
debug "Adding new node to routing table", node = $node, debug "Adding new node to routing table", node = $node,
localNode = $d.localNode localNode = $d.localNode
discard d.routingTable.addNode(node) discard d.addNode(node)
case message.kind case message.kind
of ping: of ping:
d.handlePing(sender, a, message.ping, message.reqId) if d.handlePing(sender, a, message.ping, message.reqId).isErr:
debug "Sending Pong message failed"
of findNode: of findNode:
d.handleFindNode(sender, a, message.findNode, message.reqId) if d.handleFindNode(sender, a, message.findNode, message.reqId).isErr:
debug "Sending Nodes message failed"
else: else:
var waiter: Future[Option[Message]] var waiter: Future[Option[Message]]
if d.awaitedMessages.take((sender, message.reqId), waiter): if d.awaitedMessages.take((sender, message.reqId), waiter):
waiter.complete(some(message)) waiter.complete(some(message)) # TODO: raises: [Exception]
else: else:
trace "Timed out or unrequested message", message = message.kind, trace "Timed out or unrequested message", message = message.kind,
origin = a origin = a
elif decoded.error == DecodeError.DecryptError: elif decoded.error == DecodeError.DecryptError:
debug "Could not decrypt packet, respond with whoareyou", trace "Could not decrypt packet, respond with whoareyou",
localNode = $d.localNode, address = a localNode = $d.localNode, address = a
# only sendingWhoareyou in case it is a decryption failure # only sendingWhoareyou in case it is a decryption failure
d.sendWhoareyou(a, sender, authTag) let res = d.sendWhoareyou(a, sender, authTag)
if res.isErr():
trace "Sending WhoAreYou packet failed", err = res.error
elif decoded.error == DecodeError.UnsupportedMessage: elif decoded.error == DecodeError.UnsupportedMessage:
# Still adding the node in case failure is because of unsupported message. # Still adding the node in case failure is because of unsupported message.
if not node.isNil: if not node.isNil:
if a.ip == node.address.ip: # Not filling table with nodes without correct IP in the ENR
# TODO: Should we care about this???s
if node.address.isSome() and a == node.address.get():
debug "Adding new node to routing table", node = $node, debug "Adding new node to routing table", node = $node,
localNode = $d.localNode localNode = $d.localNode
discard d.routingTable.addNode(node) discard d.addNode(node)
# elif decoded.error == DecodeError.PacketError: # elif decoded.error == DecodeError.PacketError:
# Not adding this node as from our perspective it is sending rubbish. # Not adding this node as from our perspective it is sending rubbish.
proc processClient(transp: DatagramTransport, # TODO: Not sure why but need to pop the raises here as it is apparently not
raddr: TransportAddress): Future[void] {.async, gcsafe.} = # enough to put it in the raises pragma of `processClient` and other async procs.
var proto = getUserData[Protocol](transp) {.pop.}
try: # Next, below there is no more effort done in catching the general `Exception`
# TODO: Maybe here better to use `peekMessage()` to avoid allocation, # as async procs always require `Exception` in the raises pragma, see also:
# but `Bytes` object is just a simple seq[byte], and `ByteRange` object # https://github.com/status-im/nim-chronos/issues/98
# do not support custom length. # So I don't bother for now and just add them in the raises pragma until this
var buf = transp.getMessage() # gets fixed.
let a = Address(ip: raddr.address, udpPort: raddr.port, tcpPort: raddr.port) proc processClient(transp: DatagramTransport, raddr: TransportAddress):
proto.receive(a, buf) Future[void] {.async, gcsafe, raises: [Exception, Defect].} =
except RlpError as e: let proto = getUserData[Protocol](transp)
debug "Receive failed", exception = e.name, msg = e.msg var a: Address
# TODO: what else can be raised? Figure this out and be more restrictive? var buf = newSeq[byte]()
except CatchableError as e:
debug "Receive failed", exception = e.name, msg = e.msg,
stacktrace = e.getStackTrace()
proc validIp(sender, address: IpAddress): bool = try:
a = Address(ip: raddr.address, port: raddr.port)
except ValueError:
# This should not be possible considering we bind to an IP address.
error "Not a valid IpAddress"
return
try:
# TODO: should we use `peekMessage()` to avoid allocation?
# TODO: This can still raise general `Exception` while it probably should
# only give TransportOsError.
buf = transp.getMessage()
except TransportOsError as e:
# This is likely to be local network connection issues.
error "Transport getMessage error", exception = e.name, msg = e.msg
except Exception as e:
if e of Defect:
raise (ref Defect)(e)
else: doAssert(false)
try:
proto.receive(a, buf)
except Exception as e:
if e of Defect:
raise (ref Defect)(e)
else: doAssert(false)
proc validIp(sender, address: IpAddress): bool {.raises: [Defect].} =
let let
s = initTAddress(sender, Port(0)) s = initTAddress(sender, Port(0))
a = initTAddress(address, Port(0)) a = initTAddress(address, Port(0))
@ -362,74 +442,95 @@ proc validIp(sender, address: IpAddress): bool =
# TODO: This could be improved to do the clean-up immediatily in case a non # TODO: This could be improved to do the clean-up immediatily in case a non
# whoareyou response does arrive, but we would need to store the AuthTag # whoareyou response does arrive, but we would need to store the AuthTag
# somewhere # somewhere
proc registerRequest(d: Protocol, n: Node, message: seq[byte], nonce: AuthTag) = proc registerRequest(d: Protocol, n: Node, message: seq[byte], nonce: AuthTag)
{.raises: [Exception, Defect].} =
let request = PendingRequest(node: n, message: message) let request = PendingRequest(node: n, message: message)
if not d.pendingRequests.hasKeyOrPut(nonce, request): if not d.pendingRequests.hasKeyOrPut(nonce, request):
# TODO: raises: [Exception]
sleepAsync(responseTimeout).addCallback() do(data: pointer): sleepAsync(responseTimeout).addCallback() do(data: pointer):
d.pendingRequests.del(nonce) d.pendingRequests.del(nonce)
proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Message]] = proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId):
Future[Option[Message]] {.raises: [Exception, Defect].} =
result = newFuture[Option[Message]]("waitMessage") result = newFuture[Option[Message]]("waitMessage")
let res = result let res = result
let key = (fromNode.id, reqId) let key = (fromNode.id, reqId)
# TODO: raises: [Exception]
sleepAsync(responseTimeout).addCallback() do(data: pointer): sleepAsync(responseTimeout).addCallback() do(data: pointer):
d.awaitedMessages.del(key) d.awaitedMessages.del(key)
if not res.finished: if not res.finished:
res.complete(none(Message)) res.complete(none(Message)) # TODO: raises: [Exception]
d.awaitedMessages[key] = result d.awaitedMessages[key] = result
proc addNodesFromENRs(result: var seq[Node], enrs: openarray[Record]) = proc addNodesFromENRs(result: var seq[Node], enrs: openarray[Record])
for r in enrs: result.add(newNode(r)) {.raises: [Defect].} =
for r in enrs:
let node = newNode(r)
if node.isOk():
result.add(node[])
proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]] {.async.} = proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
Future[DiscResult[seq[Node]]] {.async, raises: [Exception, Defect].} =
var op = await d.waitMessage(fromNode, reqId) var op = await d.waitMessage(fromNode, reqId)
if op.isSome and op.get.kind == nodes: if op.isSome and op.get.kind == nodes:
result.addNodesFromENRs(op.get.nodes.enrs) var res = newSeq[Node]()
res.addNodesFromENRs(op.get.nodes.enrs)
let total = op.get.nodes.total let total = op.get.nodes.total
for i in 1 ..< total: for i in 1 ..< total:
op = await d.waitMessage(fromNode, reqId) op = await d.waitMessage(fromNode, reqId)
if op.isSome and op.get.kind == nodes: if op.isSome and op.get.kind == nodes:
result.addNodesFromENRs(op.get.nodes.enrs) res.addNodesFromENRs(op.get.nodes.enrs)
else: else:
break break
return ok(res)
else:
return err("Nodes message not received in time")
proc sendPing(d: Protocol, toNode: Node): RequestId =
proc sendMessage*[T: SomeMessage](d: Protocol, toNode: Node, m: T):
DiscResult[RequestId] {.raises: [Exception, Defect].} =
doAssert(toNode.address.isSome())
let let
reqId = newRequestId().tryGet() reqId = ? newRequestId()
ping = PingMessage(enrSeq: d.localNode.record.seqNum) message = encodeMessage(m, reqId)
message = encodeMessage(ping, reqId) (data, nonce) = ? d.codec.encodePacket(toNode.id, toNode.address.get(),
(data, nonce) = d.codec.encodePacket(toNode.id, toNode.address, message, message, challenge = nil)
challenge = nil).tryGet()
d.registerRequest(toNode, message, nonce) d.registerRequest(toNode, message, nonce)
d.send(toNode, data) d.send(toNode, data)
return reqId return ok(reqId)
proc ping*(d: Protocol, toNode: Node): Future[Option[PongMessage]] {.async.} = proc ping*(d: Protocol, toNode: Node):
let reqId = d.sendPing(toNode) Future[DiscResult[PongMessage]] {.async, raises: [Exception, Defect].} =
let resp = await d.waitMessage(toNode, reqId) let reqId = d.sendMessage(toNode,
PingMessage(enrSeq: d.localNode.record.seqNum))
if reqId.isErr:
return err(reqId.error)
let resp = await d.waitMessage(toNode, reqId[])
if resp.isSome() and resp.get().kind == pong: if resp.isSome() and resp.get().kind == pong:
return some(resp.get().pong) return ok(resp.get().pong)
else:
return err("Pong message not received in time")
proc sendFindNode(d: Protocol, toNode: Node, distance: uint32): RequestId = proc findNode*(d: Protocol, toNode: Node, distance: uint32):
let reqId = newRequestId().tryGet() Future[DiscResult[seq[Node]]] {.async, raises: [Exception, Defect].} =
let message = encodeMessage(FindNodeMessage(distance: distance), reqId) let reqId = d.sendMessage(toNode, FindNodeMessage(distance: distance))
let (data, nonce) = d.codec.encodePacket(toNode.id, toNode.address, message, if reqId.isErr:
challenge = nil).tryGet() return err(reqId.error)
d.registerRequest(toNode, message, nonce) let nodes = await d.waitNodes(toNode, reqId[])
d.send(toNode, data) if nodes.isOk:
return reqId var res = newSeq[Node]()
for n in nodes[]:
if n.address.isSome() and
validIp(toNode.address.get().ip, n.address.get().ip):
res.add(n)
# TODO: Check ports
return ok(res)
else:
return err(nodes.error)
proc findNode*(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} = proc lookupDistances(target, dest: NodeId): seq[uint32] {.raises: [Defect].} =
let reqId = sendFindNode(d, toNode, distance)
let nodes = await d.waitNodes(toNode, reqId)
for n in nodes:
if validIp(toNode.address.ip, n.address.ip):
result.add(n)
proc lookupDistances(target, dest: NodeId): seq[uint32] =
let td = logDist(target, dest) let td = logDist(target, dest)
result.add(td) result.add(td)
var i = 1'u32 var i = 1'u32
@ -440,20 +541,23 @@ proc lookupDistances(target, dest: NodeId): seq[uint32] =
result.add(td - i) result.add(td - i)
inc i inc i
proc lookupWorker(d: Protocol, destNode: Node, target: NodeId): Future[seq[Node]] {.async.} = proc lookupWorker(d: Protocol, destNode: Node, target: NodeId):
Future[seq[Node]] {.async, raises: [Exception, Defect].} =
let dists = lookupDistances(target, destNode.id) let dists = lookupDistances(target, destNode.id)
var i = 0 var i = 0
while i < lookupRequestLimit and result.len < findNodeResultLimit: while i < lookupRequestLimit and result.len < findNodeResultLimit:
# TODO: Handle failures
let r = await d.findNode(destNode, dists[i]) let r = await d.findNode(destNode, dists[i])
# TODO: Handle failures better. E.g. stop on different failures than timeout
if r.isOk:
# TODO: I guess it makes sense to limit here also to `findNodeResultLimit`? # TODO: I guess it makes sense to limit here also to `findNodeResultLimit`?
result.add(r) result.add(r[])
inc i inc i
for n in result: for n in result:
discard d.routingTable.addNode(n) discard d.routingTable.addNode(n)
proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} = proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]]
{.async, raises: [Exception, Defect].} =
## Perform a lookup for the given target, return the closest n nodes to the ## Perform a lookup for the given target, return the closest n nodes to the
## target. Maximum value for n is `BUCKET_SIZE`. ## target. Maximum value for n is `BUCKET_SIZE`.
# TODO: Sort the returned nodes on distance # TODO: Sort the returned nodes on distance
@ -489,14 +593,16 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} =
if result.len < BUCKET_SIZE: if result.len < BUCKET_SIZE:
result.add(n) result.add(n)
proc lookupRandom*(d: Protocol): Future[seq[Node]] proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]]
{.raises:[RandomSourceDepleted, Defect, Exception].} = {.async, raises:[Exception, Defect].} =
var id: NodeId var id: NodeId
if randomBytes(addr id, sizeof(id)) != sizeof(id): if randomBytes(addr id, sizeof(id)) != sizeof(id):
raise newException(RandomSourceDepleted, "Could not randomize bytes") return err("Could not randomize bytes")
d.lookup(id)
proc resolve*(d: Protocol, id: NodeId): Future[Option[Node]] {.async.} = return ok(await d.lookup(id))
proc resolve*(d: Protocol, id: NodeId): Future[Option[Node]]
{.async, raises: [Exception, Defect].} =
## Resolve a `Node` based on provided `NodeId`. ## Resolve a `Node` based on provided `NodeId`.
## ##
## This will first look in the own DHT. If the node is known, it will try to ## This will first look in the own DHT. If the node is known, it will try to
@ -508,8 +614,9 @@ proc resolve*(d: Protocol, id: NodeId): Future[Option[Node]] {.async.} =
if node.isSome(): if node.isSome():
let request = await d.findNode(node.get(), 0) let request = await d.findNode(node.get(), 0)
if request.len > 0: # TODO: Handle failures better. E.g. stop on different failures than timeout
return some(request[0]) if request.isOk() and request[].len > 0:
return some(request[][0])
let discovered = await d.lookup(id) let discovered = await d.lookup(id)
for n in discovered: for n in discovered:
@ -522,11 +629,11 @@ proc resolve*(d: Protocol, id: NodeId): Future[Option[Node]] {.async.} =
return some(n) return some(n)
proc revalidateNode*(d: Protocol, n: Node) proc revalidateNode*(d: Protocol, n: Node)
{.async, raises:[Defect, Exception].} = # TODO: Exception {.async, raises: [Exception, Defect].} = # TODO: Exception
trace "Ping to revalidate node", node = $n trace "Ping to revalidate node", node = $n
let pong = await d.ping(n) let pong = await d.ping(n)
if pong.isSome(): if pong.isOK():
if pong.get().enrSeq > n.record.seqNum: if pong.get().enrSeq > n.record.seqNum:
# TODO: Request new ENR # TODO: Request new ENR
discard discard
@ -534,6 +641,8 @@ proc revalidateNode*(d: Protocol, n: Node)
d.routingTable.setJustSeen(n) d.routingTable.setJustSeen(n)
trace "Revalidated node", node = $n trace "Revalidated node", node = $n
else: else:
# TODO: Handle failures better. E.g. don't remove nodes on different
# failures than timeout
# For now we never remove bootstrap nodes. It might make sense to actually # For now we never remove bootstrap nodes. It might make sense to actually
# do so and to retry them only in case we drop to a really low amount of # do so and to retry them only in case we drop to a really low amount of
# peers in the DHT # peers in the DHT
@ -544,15 +653,14 @@ proc revalidateNode*(d: Protocol, n: Node)
# This might be to direct, so we could keep these longer. But better # This might be to direct, so we could keep these longer. But better
# would be to simply not remove the nodes immediatly but only after x # would be to simply not remove the nodes immediatly but only after x
# amount of failures. # amount of failures.
discard d.codec.db.deleteKeys(n.id, n.address) doAssert(n.address.isSome())
discard d.codec.db.deleteKeys(n.id, n.address.get())
else: else:
debug "Revalidation of bootstrap node failed", enr = toURI(n.record) debug "Revalidation of bootstrap node failed", enr = toURI(n.record)
proc revalidateLoop(d: Protocol) {.async.} = proc revalidateLoop(d: Protocol) {.async, raises: [Exception, Defect].} =
# TODO: General Exception raised.
try: try:
# TODO: We need to handle actual errors still, which might just allow to
# continue the loop. However, currently `revalidateNode` raises a general
# `Exception` making this rather hard.
while true: while true:
await sleepAsync(rand(10 * 1000).milliseconds) await sleepAsync(rand(10 * 1000).milliseconds)
let n = d.routingTable.nodeToRevalidate() let n = d.routingTable.nodeToRevalidate()
@ -563,16 +671,20 @@ proc revalidateLoop(d: Protocol) {.async.} =
except CancelledError: except CancelledError:
trace "revalidateLoop canceled" trace "revalidateLoop canceled"
proc lookupLoop(d: Protocol) {.async.} = proc lookupLoop(d: Protocol) {.async, raises: [Exception, Defect].} =
## TODO: Same story as for `revalidateLoop` # TODO: General Exception raised.
try: try:
while true: while true:
# lookup self (neighbour nodes) # lookup self (neighbour nodes)
var nodes = await d.lookup(d.localNode.id) let selfLookup = await d.lookup(d.localNode.id)
trace "Discovered nodes in self lookup", nodes = $nodes trace "Discovered nodes in self lookup", nodes = $selfLookup
nodes = await d.lookupRandom() let randomLookup = await d.lookupRandom()
trace "Discovered nodes in random lookup", nodes = $nodes if randomLookup.isOK:
trace "Discovered nodes in random lookup", nodes = $randomLookup[]
trace "Total nodes in routing table", total = d.routingTable.len()
else:
trace "random lookup failed", err = randomLookup.error
await sleepAsync(lookupInterval) await sleepAsync(lookupInterval)
except CancelledError: except CancelledError:
trace "lookupLoop canceled" trace "lookupLoop canceled"
@ -580,14 +692,12 @@ proc lookupLoop(d: Protocol) {.async.} =
proc newProtocol*(privKey: PrivateKey, db: Database, proc newProtocol*(privKey: PrivateKey, db: Database,
externalIp: Option[IpAddress], tcpPort, udpPort: Port, externalIp: Option[IpAddress], tcpPort, udpPort: Port,
localEnrFields: openarray[FieldPair] = [], localEnrFields: openarray[FieldPair] = [],
bootstrapRecords: openarray[Record] = []): Protocol = bootstrapRecords: openarray[Record] = []):
Protocol {.raises: [Defect].} =
let let
a = Address(ip: externalIp.get(IPv4_any()),
tcpPort: tcpPort, udpPort: udpPort)
enode = ENode(pubkey: privKey.toPublicKey().tryGet(), address: a)
enrRec = enr.Record.init(1, privKey, externalIp, tcpPort, udpPort, enrRec = enr.Record.init(1, privKey, externalIp, tcpPort, udpPort,
localEnrFields).expect("Properly intialized private key") localEnrFields).expect("Properly intialized private key")
node = newNode(enode, enrRec) node = newNode(enrRec).expect("Properly initialized node")
result = Protocol( result = Protocol(
privateKey: privKey, privateKey: privKey,
@ -600,23 +710,24 @@ proc newProtocol*(privKey: PrivateKey, db: Database,
result.routingTable.init(node) result.routingTable.init(node)
proc open*(d: Protocol) = proc open*(d: Protocol) {.raises: [Exception, Defect].} =
info "Starting discovery node", node = $d.localNode, info "Starting discovery node", node = $d.localNode,
uri = toURI(d.localNode.record) uri = toURI(d.localNode.record)
# TODO allow binding to specific IP / IPv6 / etc # TODO allow binding to specific IP / IPv6 / etc
let ta = initTAddress(IPv4_any(), d.localNode.node.address.udpPort) let ta = initTAddress(IPv4_any(), Port(d.localNode.address.get().port))
# TODO: raises `OSError` and `IOSelectorsException`, the latter which is
# object of Exception. In Nim devel this got changed to CatchableError.
d.transp = newDatagramTransport(processClient, udata = d, local = ta) d.transp = newDatagramTransport(processClient, udata = d, local = ta)
for record in d.bootstrapRecords: for record in d.bootstrapRecords:
debug "Adding bootstrap node", uri = toURI(record) debug "Adding bootstrap node", uri = toURI(record)
d.addNode(record) discard d.addNode(record)
proc start*(d: Protocol) = proc start*(d: Protocol) {.raises: [Exception, Defect].} =
# Might want to move these to a separate proc if this turns out to be needed.
d.lookupLoop = lookupLoop(d) d.lookupLoop = lookupLoop(d)
d.revalidateLoop = revalidateLoop(d) d.revalidateLoop = revalidateLoop(d)
proc close*(d: Protocol) = proc close*(d: Protocol) {.raises: [Exception, Defect].} =
doAssert(not d.transp.closed) doAssert(not d.transp.closed)
debug "Closing discovery node", node = $d.localNode debug "Closing discovery node", node = $d.localNode
@ -624,11 +735,10 @@ proc close*(d: Protocol) =
d.revalidateLoop.cancel() d.revalidateLoop.cancel()
if not d.lookupLoop.isNil: if not d.lookupLoop.isNil:
d.lookupLoop.cancel() d.lookupLoop.cancel()
# TODO: unsure if close can't create issues in the not awaited cancellations
# above
d.transp.close() d.transp.close()
proc closeWait*(d: Protocol) {.async.} = proc closeWait*(d: Protocol) {.async, raises: [Exception, Defect].} =
doAssert(not d.transp.closed) doAssert(not d.transp.closed)
debug "Closing discovery node", node = $d.localNode debug "Closing discovery node", node = $d.localNode

View File

@ -1,7 +1,7 @@
import import
std/[algorithm, times, sequtils, bitops, random, sets, options], std/[algorithm, times, sequtils, bitops, random, sets, options],
stint, chronicles, stint, chronicles,
types, node node
{.push raises: [Defect].} {.push raises: [Defect].}

View File

@ -1,6 +1,8 @@
import import
hashes, stint, hashes, stint, chronos,
eth/[keys, rlp], ../enode, enr eth/[keys, rlp], enr, node
{.push raises: [Defect].}
const const
authTagSize* = 12 authTagSize* = 12
@ -8,7 +10,6 @@ const
aesKeySize* = 128 div 8 aesKeySize* = 128 div 8
type type
NodeId* = UInt256
AuthTag* = array[authTagSize, byte] AuthTag* = array[authTagSize, byte]
IdNonce* = array[idNonceSize, byte] IdNonce* = array[idNonceSize, byte]
AesKey* = array[aesKeySize, byte] AesKey* = array[aesKeySize, byte]
@ -82,14 +83,14 @@ template messageKind*(T: typedesc[SomeMessage]): MessageKind =
elif T is FindNodeMessage: findNode elif T is FindNodeMessage: findNode
elif T is NodesMessage: nodes elif T is NodesMessage: nodes
method storeKeys*(db: Database, id: NodeId, address: Address, r, w: AesKey): method storeKeys*(db: Database, id: NodeId, address: Address,
bool {.base, raises: [Defect].} = discard r, w: AesKey): bool {.base.} = discard
method loadKeys*(db: Database, id: NodeId, address: Address, r, w: var AesKey): method loadKeys*(db: Database, id: NodeId, address: Address,
bool {.base, raises: [Defect].} = discard r, w: var AesKey): bool {.base.} = discard
method deleteKeys*(db: Database, id: NodeId, address: Address): method deleteKeys*(db: Database, id: NodeId, address: Address):
bool {.raises: [Defect].} = discard bool {.base.} = discard
proc toBytes*(id: NodeId): array[32, byte] {.inline.} = proc toBytes*(id: NodeId): array[32, byte] {.inline.} =
id.toByteArrayBE() id.toByteArrayBE()

View File

@ -1,17 +1,20 @@
import import
unittest, chronos, sequtils, chronicles, tables, stint, nimcrypto, unittest, chronos, sequtils, chronicles, tables, stint, nimcrypto,
eth/[keys, rlp], eth/p2p/enode, eth/trie/db, eth/[keys, rlp], eth/trie/db,
eth/p2p/discoveryv5/[discovery_db, enr, node, types, routing_table, encoding], eth/p2p/discoveryv5/[discovery_db, enr, node, types, routing_table, encoding],
eth/p2p/discoveryv5/protocol as discv5_protocol, eth/p2p/discoveryv5/protocol as discv5_protocol,
./p2p_test_helper ./p2p_test_helper
proc localAddress*(port: int): Address =
Address(ip: parseIpAddress("127.0.0.1"), port: Port(port))
proc initDiscoveryNode*(privKey: PrivateKey, address: Address, proc initDiscoveryNode*(privKey: PrivateKey, address: Address,
bootstrapRecords: openarray[Record] = []): bootstrapRecords: openarray[Record] = []):
discv5_protocol.Protocol = discv5_protocol.Protocol =
var db = DiscoveryDB.init(newMemoryDB()) var db = DiscoveryDB.init(newMemoryDB())
result = newProtocol(privKey, db, result = newProtocol(privKey, db,
some(parseIpAddress("127.0.0.1")), some(address.ip),
address.tcpPort, address.udpPort, address.port, address.port,
bootstrapRecords = bootstrapRecords) bootstrapRecords = bootstrapRecords)
result.open() result.open()
@ -26,8 +29,8 @@ proc randomPacket(tag: PacketTag): seq[byte] =
authTag: AuthTag authTag: AuthTag
msg: array[44, byte] msg: array[44, byte]
require randomBytes(authTag) == authTag.len check randomBytes(authTag) == authTag.len
require randomBytes(msg) == msg.len check randomBytes(msg) == msg.len
result.add(tag) result.add(tag)
result.add(rlp.encode(authTag)) result.add(rlp.encode(authTag))
result.add(msg) result.add(msg)
@ -36,7 +39,7 @@ proc generateNode(privKey = PrivateKey.random()[], port: int = 20302): Node =
let port = Port(port) let port = Port(port)
let enr = enr.Record.init(1, privKey, some(parseIpAddress("127.0.0.1")), let enr = enr.Record.init(1, privKey, some(parseIpAddress("127.0.0.1")),
port, port).expect("Properly intialized private key") port, port).expect("Properly intialized private key")
result = newNode(enr) result = newNode(enr).expect("Properly initialized node")
proc nodeAtDistance(n: Node, d: uint32): Node = proc nodeAtDistance(n: Node, d: uint32): Node =
while true: while true:
@ -55,13 +58,13 @@ suite "Discovery v5 Tests":
node = initDiscoveryNode(PrivateKey.random()[], localAddress(20302)) node = initDiscoveryNode(PrivateKey.random()[], localAddress(20302))
targetNode = generateNode() targetNode = generateNode()
node.addNode(targetNode) check node.addNode(targetNode)
for i in 0..<1000: for i in 0..<1000:
node.addNode(generateNode()) discard node.addNode(generateNode())
let n = node.getNode(targetNode.id) let n = node.getNode(targetNode.id)
require n.isSome() check n.isSome()
check n.get() == targetNode check n.get() == targetNode
await node.closeWait() await node.closeWait()
@ -76,7 +79,7 @@ suite "Discovery v5 Tests":
pong1 = await discv5_protocol.ping(node1, bootnode.localNode) pong1 = await discv5_protocol.ping(node1, bootnode.localNode)
pong2 = await discv5_protocol.ping(node1, node2.localNode) pong2 = await discv5_protocol.ping(node1, node2.localNode)
check pong1.isSome() and pong2.isSome() check pong1.isOk() and pong2.isOk()
await bootnode.closeWait() await bootnode.closeWait()
await node2.closeWait() await node2.closeWait()
@ -85,9 +88,10 @@ suite "Discovery v5 Tests":
await node1.revalidateNode(node2.localNode) await node1.revalidateNode(node2.localNode)
let n = node1.getNode(bootnode.localNode.id) let n = node1.getNode(bootnode.localNode.id)
require n.isSome() check:
check n.get() == bootnode.localNode n.isSome()
check node1.getNode(node2.localNode.id).isNone() n.get() == bootnode.localNode
node1.getNode(node2.localNode.id).isNone()
await node1.closeWait() await node1.closeWait()
@ -98,7 +102,7 @@ suite "Discovery v5 Tests":
let a = localAddress(20303) let a = localAddress(20303)
for i in 0 ..< 5: for i in 0 ..< 5:
require randomBytes(tag) == tag.len check randomBytes(tag) == tag.len
node.receive(a, randomPacket(tag)) node.receive(a, randomPacket(tag))
# Checking different nodeIds but same address # Checking different nodeIds but same address
@ -225,42 +229,47 @@ suite "Discovery v5 Tests":
let nodes = nodesAtDistance(mainNode.localNode, dist, 10) let nodes = nodesAtDistance(mainNode.localNode, dist, 10)
for n in nodes: for n in nodes:
mainNode.addNode(n) discard mainNode.addNode(n)
# Get ENR of the node itself # Get ENR of the node itself
var discovered = var discovered =
await discv5_protocol.findNode(testNode, mainNode.localNode, 0) await discv5_protocol.findNode(testNode, mainNode.localNode, 0)
check: check:
discovered.len == 1 discovered.isOk
discovered[0] == mainNode.localNode discovered[].len == 1
discovered[][0] == mainNode.localNode
# Get ENRs of nodes added at provided logarithmic distance # Get ENRs of nodes added at provided logarithmic distance
discovered = discovered =
await discv5_protocol.findNode(testNode, mainNode.localNode, dist) await discv5_protocol.findNode(testNode, mainNode.localNode, dist)
check discovered.len == 10 check discovered.isOk
check discovered[].len == 10
for n in nodes: for n in nodes:
check discovered.contains(n) check discovered[].contains(n)
# Too high logarithmic distance, caps at 256 # Too high logarithmic distance, caps at 256
discovered = discovered =
await discv5_protocol.findNode(testNode, mainNode.localNode, 4294967295'u32) await discv5_protocol.findNode(testNode, mainNode.localNode, 4294967295'u32)
check: check:
discovered.len == 1 discovered.isOk
discovered[0] == testNode.localNode discovered[].len == 1
discovered[][0] == testNode.localNode
# Empty bucket # Empty bucket
discovered = discovered =
await discv5_protocol.findNode(testNode, mainNode.localNode, 254) await discv5_protocol.findNode(testNode, mainNode.localNode, 254)
check discovered.len == 0 check discovered.isOk
check discovered[].len == 0
let moreNodes = nodesAtDistance(mainNode.localNode, dist, 10) let moreNodes = nodesAtDistance(mainNode.localNode, dist, 10)
for n in moreNodes: for n in moreNodes:
mainNode.addNode(n) discard mainNode.addNode(n)
# Full bucket # Full bucket
discovered = discovered =
await discv5_protocol.findNode(testNode, mainNode.localNode, dist) await discv5_protocol.findNode(testNode, mainNode.localNode, dist)
check discovered.len == 16 check discovered.isOk
check discovered[].len == 16
await mainNode.closeWait() await mainNode.closeWait()
await testNode.closeWait() await testNode.closeWait()
@ -271,7 +280,7 @@ suite "Discovery v5 Tests":
# Generate 1000 random nodes and add to our main node's routing table # Generate 1000 random nodes and add to our main node's routing table
for i in 0..<1000: for i in 0..<1000:
mainNode.addNode(generateNode()) discard mainNode.addNode(generateNode())
let let
neighbours = mainNode.neighbours(mainNode.localNode.id) neighbours = mainNode.neighbours(mainNode.localNode.id)
@ -286,7 +295,8 @@ suite "Discovery v5 Tests":
discovered = await discv5_protocol.findNode(testNode, mainNode.localNode, discovered = await discv5_protocol.findNode(testNode, mainNode.localNode,
closestDistance) closestDistance)
check closest in discovered check discovered.isOk
check closest in discovered[]
await mainNode.closeWait() await mainNode.closeWait()
await testNode.closeWait() await testNode.closeWait()
@ -330,11 +340,11 @@ suite "Discovery v5 Tests":
# if resolve works (only local lookup) # if resolve works (only local lookup)
block: block:
let pong = await targetNode.ping(mainNode.localNode) let pong = await targetNode.ping(mainNode.localNode)
require pong.isSome() check pong.isOk()
await targetNode.closeWait() await targetNode.closeWait()
let n = await mainNode.resolve(targetId) let n = await mainNode.resolve(targetId)
require n.isSome()
check: check:
n.isSome()
n.get().id == targetId n.get().id == targetId
n.get().record.seqNum == targetSeqNum n.get().record.seqNum == targetSeqNum
@ -344,12 +354,12 @@ suite "Discovery v5 Tests":
# TODO: need to add some logic to update ENRs properly # TODO: need to add some logic to update ENRs properly
targetSeqNum.inc() targetSeqNum.inc()
let r = enr.Record.init(targetSeqNum, targetKey, let r = enr.Record.init(targetSeqNum, targetKey,
some(targetAddress.ip), targetAddress.tcpPort, targetAddress.udpPort)[] some(targetAddress.ip), targetAddress.port, targetAddress.port)[]
targetNode.localNode.record = r targetNode.localNode.record = r
targetNode.open() targetNode.open()
let n = await mainNode.resolve(targetId) let n = await mainNode.resolve(targetId)
require n.isSome()
check: check:
n.isSome()
n.get().id == targetId n.get().id == targetId
n.get().record.seqNum == targetSeqNum n.get().record.seqNum == targetSeqNum
@ -358,20 +368,20 @@ suite "Discovery v5 Tests":
block: block:
targetSeqNum.inc() targetSeqNum.inc()
let r = enr.Record.init(3, targetKey, some(targetAddress.ip), let r = enr.Record.init(3, targetKey, some(targetAddress.ip),
targetAddress.tcpPort, targetAddress.udpPort)[] targetAddress.port, targetAddress.port)[]
targetNode.localNode.record = r targetNode.localNode.record = r
let pong = await targetNode.ping(lookupNode.localNode) let pong = await targetNode.ping(lookupNode.localNode)
require pong.isSome() check pong.isOk()
await targetNode.closeWait() await targetNode.closeWait()
# TODO: This step should eventually not be needed and ENRs with new seqNum # TODO: This step should eventually not be needed and ENRs with new seqNum
# should just get updated in the lookup. # should just get updated in the lookup.
await mainNode.revalidateNode(targetNode.localNode) await mainNode.revalidateNode(targetNode.localNode)
mainNode.addNode(lookupNode.localNode.record) check mainNode.addNode(lookupNode.localNode.record)
let n = await mainNode.resolve(targetId) let n = await mainNode.resolve(targetId)
require n.isSome()
check: check:
n.isSome()
n.get().id == targetId n.get().id == targetId
n.get().record.seqNum == targetSeqNum n.get().record.seqNum == targetSeqNum

View File

@ -169,7 +169,7 @@ suite "Discovery v5 Cryptographic Primitives":
privKey = PrivateKey.fromHex(localSecretKey)[] privKey = PrivateKey.fromHex(localSecretKey)[]
signature = signIDNonce(privKey, hexToByteArray[idNonceSize](idNonce), signature = signIDNonce(privKey, hexToByteArray[idNonceSize](idNonce),
hexToByteArray[64](ephemeralKey)) hexToByteArray[64](ephemeralKey))
require signature.isOK() check signature.isOK()
check signature[].toRaw() == hexToByteArray[64](idNonceSig) check signature[].toRaw() == hexToByteArray[64](idNonceSig)
test "Encryption/Decryption": test "Encryption/Decryption":