diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 80ec19d..b1ab8ab 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -70,35 +70,8 @@ proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] = proc nodesDiscovered*(d: Protocol): int {.inline.} = d.routingTable.len -proc whoareyouMagic(toNode: NodeId): array[magicSize, byte] = - const prefix = "WHOAREYOU" - var data: array[prefix.len + sizeof(toNode), byte] - data[0 .. sizeof(toNode) - 1] = toNode.toByteArrayBE() - for i, c in prefix: data[sizeof(toNode) + i] = byte(c) - sha256.digest(data).data - -proc newProtocol*(privKey: PrivateKey, db: Database, - ip: IpAddress, tcpPort, udpPort: Port, - bootstrapRecords: openarray[Record] = []): Protocol = - let - a = Address(ip: ip, tcpPort: tcpPort, udpPort: udpPort) - enode = initENode(privKey.getPublicKey(), a) - enrRec = enr.Record.init(12, privKey, some(a)) - node = newNode(enode, enrRec) - - result = Protocol( - privateKey: privKey, - db: db, - localNode: node, - whoareyouMagic: whoareyouMagic(node.id), - idHash: sha256.digest(node.id.toByteArrayBE).data, - codec: Codec(localNode: node, privKey: privKey, db: db), - bootstrapNodes: newNodes(bootstrapRecords)) - - result.routingTable.init(node) - -func privKey*(p: Protocol): lent PrivateKey = - p.privateKey +func privKey*(d: Protocol): lent PrivateKey = + d.privateKey proc send(d: Protocol, a: Address, data: seq[byte]) = # debug "Sending bytes", amount = data.len, to = a @@ -115,6 +88,13 @@ proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] = for i in 0 .. a.high: result[i] = a[i] xor b[i] +proc whoareyouMagic(toNode: NodeId): array[magicSize, byte] = + const prefix = "WHOAREYOU" + var data: array[prefix.len + sizeof(toNode), byte] + data[0 .. sizeof(toNode) - 1] = toNode.toByteArrayBE() + for i, c in prefix: data[sizeof(toNode) + i] = byte(c) + sha256.digest(data).data + proc isWhoAreYou(d: Protocol, msg: Bytes): bool = if msg.len > d.whoareyouMagic.len: result = d.whoareyouMagic == msg.toOpenArray(0, magicSize - 1) @@ -262,6 +242,32 @@ proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe, debug "Adding new node to routing table", node = $node, localNode = $d.localNode discard d.routingTable.addNode(node) +proc processClient(transp: DatagramTransport, + raddr: TransportAddress): Future[void] {.async, gcsafe.} = + var proto = getUserData[Protocol](transp) + try: + # TODO: Maybe here better to use `peekMessage()` to avoid allocation, + # but `Bytes` object is just a simple seq[byte], and `ByteRange` object + # do not support custom length. + var buf = transp.getMessage() + let a = Address(ip: raddr.address, udpPort: raddr.port, tcpPort: raddr.port) + proto.receive(a, buf) + except RlpError as e: + debug "Receive failed", exception = e.name, msg = e.msg + # TODO: what else can be raised? Figure this out and be more restrictive? + except CatchableError as e: + debug "Receive failed", exception = e.name, msg = e.msg, + stacktrace = e.getStackTrace() + +# TODO: This could be improved to do the clean-up immediatily in case a non +# whoareyou response does arrive, but we would need to store the AuthTag +# somewhere +proc registerRequest(d: Protocol, n: Node, packet: seq[byte], nonce: AuthTag) = + let request = PendingRequest(node: n, packet: packet) + if not d.pendingRequests.hasKeyOrPut(nonce, request): + sleepAsync(responseTimeout).addCallback() do(data: pointer): + d.pendingRequests.del(nonce) + proc waitPacket(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Packet]] = result = newFuture[Option[Packet]]("waitPacket") let res = result @@ -287,14 +293,23 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]] else: break -# TODO: This could be improved to do the clean-up immediatily in case a non -# whoareyou response does arrive, but we would need to store the AuthTag -# somewhere -proc registerRequest(d: Protocol, n: Node, packet: seq[byte], nonce: AuthTag) = - let request = PendingRequest(node: n, packet: packet) - if not d.pendingRequests.hasKeyOrPut(nonce, request): - sleepAsync(responseTimeout).addCallback() do(data: pointer): - d.pendingRequests.del(nonce) +proc sendPing(d: Protocol, toNode: Node): RequestId = + let + reqId = newRequestId() + ping = PingPacket(enrSeq: d.localNode.record.seqNum) + packet = encodePacket(ping, reqId) + (data, nonce) = d.codec.encodeEncrypted(toNode.id, toNode.address, packet, + challenge = nil) + d.registerRequest(toNode, packet, nonce) + d.send(toNode, data) + return reqId + +proc ping(d: Protocol, toNode: Node): Future[Option[PongPacket]] {.async.} = + let reqId = d.sendPing(toNode) + let resp = await d.waitPacket(toNode, reqId) + + if resp.isSome() and resp.get().kind == pong: + return some(resp.get().pong) proc sendFindNode(d: Protocol, toNode: Node, distance: uint32): RequestId = let reqId = newRequestId() @@ -321,26 +336,26 @@ proc lookupDistances(target, dest: NodeId): seq[uint32] = result.add(td - i) inc i -proc lookupWorker(p: Protocol, destNode: Node, target: NodeId): Future[seq[Node]] {.async.} = +proc lookupWorker(d: Protocol, destNode: Node, target: NodeId): Future[seq[Node]] {.async.} = let dists = lookupDistances(target, destNode.id) var i = 0 while i < lookupRequestLimit and result.len < findNodeResultLimit: # TODO: Handle failures - let r = await p.findNode(destNode, dists[i]) + let r = await d.findNode(destNode, dists[i]) # TODO: I guess it makes sense to limit here also to `findNodeResultLimit`? result.add(r) inc i for n in result: - discard p.routingTable.addNode(n) + discard d.routingTable.addNode(n) -proc lookup*(p: Protocol, target: NodeId): Future[seq[Node]] {.async.} = +proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} = ## Perform a lookup for the given target, return the closest n nodes to the ## target. Maximum value for n is `BUCKET_SIZE`. # TODO: Sort the returned nodes on distance - result = p.routingTable.neighbours(target, BUCKET_SIZE) + result = d.routingTable.neighbours(target, BUCKET_SIZE) var asked = initHashSet[NodeId]() - asked.incl(p.localNode.id) + asked.incl(d.localNode.id) var seen = asked for node in result: seen.incl(node.id) @@ -352,7 +367,7 @@ proc lookup*(p: Protocol, target: NodeId): Future[seq[Node]] {.async.} = while i < result.len and pendingQueries.len < alpha: let n = result[i] if not asked.containsOrIncl(n.id): - pendingQueries.add(p.lookupWorker(n, target)) + pendingQueries.add(d.lookupWorker(n, target)) inc i trace "discv5 pending queries", total = pendingQueries.len @@ -370,50 +385,20 @@ proc lookup*(p: Protocol, target: NodeId): Future[seq[Node]] {.async.} = if result.len < BUCKET_SIZE: result.add(n) -proc lookupRandom*(p: Protocol): Future[seq[Node]] +proc lookupRandom*(d: Protocol): Future[seq[Node]] {.raises:[RandomSourceDepleted, Defect, Exception].} = var id: NodeId if randomBytes(addr id, sizeof(id)) != sizeof(id): raise newException(RandomSourceDepleted, "Could not randomize bytes") - p.lookup(id) - -proc processClient(transp: DatagramTransport, - raddr: TransportAddress): Future[void] {.async, gcsafe.} = - var proto = getUserData[Protocol](transp) - try: - # TODO: Maybe here better to use `peekMessage()` to avoid allocation, - # but `Bytes` object is just a simple seq[byte], and `ByteRange` object - # do not support custom length. - var buf = transp.getMessage() - let a = Address(ip: raddr.address, udpPort: raddr.port, tcpPort: raddr.port) - proto.receive(a, buf) - except RlpError as e: - debug "Receive failed", exception = e.name, msg = e.msg - # TODO: what else can be raised? Figure this out and be more restrictive? - except CatchableError as e: - debug "Receive failed", exception = e.name, msg = e.msg, - stacktrace = e.getStackTrace() - -proc sendPing(d: Protocol, toNode: Node): RequestId = - let - reqId = newRequestId() - ping = PingPacket(enrSeq: d.localNode.record.seqNum) - packet = encodePacket(ping, reqId) - (data, nonce) = d.codec.encodeEncrypted(toNode.id, toNode.address, packet, - challenge = nil) - d.registerRequest(toNode, packet, nonce) - d.send(toNode, data) - return reqId + d.lookup(id) proc revalidateNode(d: Protocol, n: Node) {.async, raises:[Defect, Exception].} = # TODO: Exception trace "Ping to revalidate node", node = $n - let reqId = d.sendPing(n) + let pong = await d.ping(n) - let resp = await d.waitPacket(n, reqId) - if resp.isSome and resp.get.kind == pong: - let pong = resp.get.pong - if pong.enrSeq > n.record.seqNum: + if pong.isSome(): + if pong.get().enrSeq > n.record.seqNum: # TODO: Request new ENR discard @@ -461,6 +446,26 @@ proc lookupLoop(d: Protocol) {.async.} = except CancelledError: trace "lookupLoop canceled" +proc newProtocol*(privKey: PrivateKey, db: Database, + ip: IpAddress, tcpPort, udpPort: Port, + bootstrapRecords: openarray[Record] = []): Protocol = + let + a = Address(ip: ip, tcpPort: tcpPort, udpPort: udpPort) + enode = initENode(privKey.getPublicKey(), a) + enrRec = enr.Record.init(12, privKey, some(a)) + node = newNode(enode, enrRec) + + result = Protocol( + privateKey: privKey, + db: db, + localNode: node, + whoareyouMagic: whoareyouMagic(node.id), + idHash: sha256.digest(node.id.toByteArrayBE).data, + codec: Codec(localNode: node, privKey: privKey, db: db), + bootstrapNodes: newNodes(bootstrapRecords)) + + result.routingTable.init(node) + proc open*(d: Protocol) = debug "Starting discovery node", node = $d.localNode, uri = toURI(d.localNode.record) @@ -498,48 +503,3 @@ proc closeWait*(d: Protocol) {.async.} = await d.lookupLoop.cancelAndWait() await d.transp.closeWait() - -when isMainModule: - import discovery_db - import eth/trie/db - - proc genDiscoveries(n: int): seq[Protocol] = - var pks = ["98b3d4d4fe348ac5192d16b46aa36c41f847b9f265ba4d56f6326669449a968b", "88d125288fbb19ecd7b6a355faf3e842e3c6158d38af14bb97ac8d957ec9cb58", "c9a24471d2f84efa103b9abbdedd4c0fea8402f94e5ceb3ca4d9cff951fc407f"] - for i in 0 ..< n: - var pk: PrivateKey - if i < pks.len: - pk = initPrivateKey(pks[i]) - else: - pk = newPrivateKey() - - let d = newProtocol(pk, DiscoveryDB.init(newMemoryDB()), - parseIpAddress("127.0.0.1"), Port(12001 + i), Port(12001 + i)) - d.open() - result.add(d) - - proc addNode(d: openarray[Protocol], enr: string) = - for dd in d: dd.addNode(EnrUri(enr)) - - proc test() {.async.} = - block: - let d = genDiscoveries(3) - d.addNode("enr:-IS4QPvi3TdAUd2Jdrx-8ScRbCzrV1kVsTTM02mfz8Fx7CtrAfYN7AjxTx3MWbY2efRmAhS-Yyv4nhyzKu_YS6jSh08BgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQJeWTAJhJYN2q3BvcQwsyo7pIi8KnfwDIrhNdflCFvqr4N1ZHCCD6A") - - for i, dd in d: - let nodes = await dd.lookupRandom() - echo "NODES ", i, ": ", nodes - - # block: - # var d = genDiscoveries(4) - # let rootD = d[0] - # d.del(0) - - - # d.addNode(rootD.localNode.record.toUri) - - # for i, dd in d: - # let nodes = await dd.lookupRandom() - # echo "NODES ", i, ": ", nodes - - waitFor test() - runForever()