diff --git a/eth/p2p/discoveryv5/enr.nim b/eth/p2p/discoveryv5/enr.nim index d52c2a5..1ff34e7 100644 --- a/eth/p2p/discoveryv5/enr.nim +++ b/eth/p2p/discoveryv5/enr.nim @@ -101,15 +101,19 @@ macro initRecord*(seqNum: uint64, pk: PrivateKey, pairs: untyped{nkTableConstr}) proc init*(T: type Record, seqNum: uint64, pk: PrivateKey, - address: enode.Address): T = - let - isV6 = address.ip.family == IPv6 - ipField = if isV6: ("ip6", address.ip.address_v6.toField) - else: ("ip", address.ip.address_v4.toField) - tcpField = ((if isV6: "tcp6" else: "tcp"), address.tcpPort.uint16.toField) - udpField = ((if isV6: "udp6" else: "udp"), address.udpPort.uint16.toField) + address: Option[enode.Address]): T = + if address.isSome(): + let + a = address.get() + isV6 = a.ip.family == IPv6 + ipField = if isV6: ("ip6", a.ip.address_v6.toField) + else: ("ip", a.ip.address_v4.toField) + tcpField = ((if isV6: "tcp6" else: "tcp"), a.tcpPort.uint16.toField) + udpField = ((if isV6: "udp6" else: "udp"), a.udpPort.uint16.toField) - makeEnrAux(seqNum, pk, [ipField, tcpField, udpField]) + makeEnrAux(seqNum, pk, [ipField, tcpField, udpField]) + else: + makeEnrAux(seqNum, pk, []) proc getField(r: Record, name: string, field: var Field): bool = # It might be more correct to do binary search, diff --git a/eth/p2p/discoveryv5/node.nim b/eth/p2p/discoveryv5/node.nim index 199b00a..9a0d8be 100644 --- a/eth/p2p/discoveryv5/node.nim +++ b/eth/p2p/discoveryv5/node.nim @@ -28,19 +28,23 @@ proc newNode*(pk: PublicKey, address: Address): Node = proc newNode*(r: Record): Node = # TODO: Handle IPv6 - let - ipBytes = r.get("ip", array[4, byte]) - udpPort = r.get("udp", uint16) + var a: Address + try: + let + ipBytes = r.get("ip", array[4, byte]) + udpPort = r.get("udp", uint16) + + a = Address(ip: IpAddress(family: IpAddressFamily.IPv4, + address_v4: ipBytes), + udpPort: Port udpPort) + except KeyError: + discard var pk: PublicKey if recoverPublicKey(r.get("secp256k1", seq[byte]), pk) != EthKeysStatus.Success: warn "Could not recover public key" return - let a = Address(ip: IpAddress(family: IpAddressFamily.IPv4, - address_v4: ipBytes), - udpPort: Port udpPort) - result = newNode(initENode(pk, a)) result.record = r diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 5964fe4..d7dc0ae 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -59,10 +59,16 @@ proc addNode*(d: Protocol, enr: EnrUri) = doAssert(res) d.addNode newNode(r) -proc randomNodes*(k: Protocol, count: int): seq[Node] = - k.routingTable.randomNodes(count) +proc getNode*(d: Protocol, id: NodeId): Node = + d.routingTable.getNode(id) -proc nodesDiscovered*(k: Protocol): int {.inline.} = k.routingTable.len +proc randomNodes*(d: Protocol, count: int): seq[Node] = + d.routingTable.randomNodes(count) + +proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] = + d.routingTable.neighbours(id, k) + +proc nodesDiscovered*(d: Protocol): int {.inline.} = d.routingTable.len proc whoareyouMagic(toNode: NodeId): array[magicSize, byte] = const prefix = "WHOAREYOU" @@ -77,7 +83,7 @@ proc newProtocol*(privKey: PrivateKey, db: Database, let a = Address(ip: ip, tcpPort: tcpPort, udpPort: udpPort) enode = initENode(privKey.getPublicKey(), a) - enrRec = enr.Record.init(12, privKey, a) + enrRec = enr.Record.init(12, privKey, some(a)) node = newNode(enode, enrRec) result = Protocol( @@ -281,7 +287,7 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]] else: break -proc findNode(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} = +proc findNode*(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} = let reqId = newRequestId() let packet = encodePacket(FindNodePacket(distance: distance), reqId) let (data, nonce) = d.codec.encodeEncrypted(toNode.id, toNode.address, packet, diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index 42d442d..6ea5e07 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -1,5 +1,5 @@ import - random, unittest, chronos, sequtils, chronicles, tables, stint, + random, unittest, chronos, sequtils, chronicles, tables, stint, options, eth/[keys, rlp], eth/p2p/enode, eth/trie/db, eth/p2p/discoveryv5/[discovery_db, enr, node, types, routing_table, encoding], eth/p2p/discoveryv5/protocol as discv5_protocol, @@ -32,6 +32,10 @@ proc randomPacket(tag: PacketTag): seq[byte] = result.add(rlp.encode(authTag)) result.add(msg) +proc generateNode(privKey = newPrivateKey()): Node = + let enr = enr.Record.init(1, privKey, none(Address)) + result = newNode(enr) + suite "Discovery v5 Tests": asyncTest "Random nodes": let @@ -88,6 +92,47 @@ suite "Discovery v5 Tests": for node in nodes: await node.closeWait() + asyncTest "FindNode with test table": + + let mainNode = initDiscoveryNode(newPrivateKey(), localAddress(20301), @[]) + + # Generate 1000 random nodes and add to our main node's routing table + for i in 0..<1000: + mainNode.addNode(generateNode()) + + let + neighbours = mainNode.neighbours(mainNode.localNode.id) + closest = neighbours[0] + closestDistance = logDist(closest.id, mainNode.localNode.id) + + debug "Closest neighbour", closestDistance, id=closest.id.toHex() + + let + testNode = initDiscoveryNode(newPrivateKey(), localAddress(20302), + @[mainNode.localNode.record]) + discovered = await discv5_protocol.findNode(testNode, mainNode.localNode, + closestDistance) + + check closest in discovered + + await mainNode.closeWait() + await testNode.closeWait() + + asyncTest "GetNode": + # TODO: This could be tested in just a routing table only context + let + node = initDiscoveryNode(newPrivateKey(), localAddress(20302), @[]) + targetNode = generateNode() + + node.addNode(targetNode) + + for i in 0..<1000: + node.addNode(generateNode()) + + check node.getNode(targetNode.id) == targetNode + + await node.closeWait() + asyncTest "Handshake cleanup": let node = initDiscoveryNode(newPrivateKey(), localAddress(20302), @[]) var tag: PacketTag diff --git a/tests/p2p/test_enr.nim b/tests/p2p/test_enr.nim index 92d2dfb..d275efe 100644 --- a/tests/p2p/test_enr.nim +++ b/tests/p2p/test_enr.nim @@ -35,19 +35,35 @@ suite "ENR": keys = newKeyPair() ip = parseIpAddress("10.20.30.40") enodeAddress = Address(ip: ip, tcpPort: Port 9000, udpPort: Port 9000) - enr = Record.init(100, keys.seckey, enodeAddress) - typedEnr = get enr.toTypedRecord + enr = Record.init(100, keys.seckey, some(enodeAddress)) + typedEnr = get enr.toTypedRecord() check: - typedEnr.secp256k1.isSome - typedEnr.secp256k1.get == keys.pubkey.getRawCompressed + typedEnr.secp256k1.isSome() + typedEnr.secp256k1.get == keys.pubkey.getRawCompressed() - typedEnr.ip.isSome - typedEnr.ip.get == [byte 10, 20, 30, 40] + typedEnr.ip.isSome() + typedEnr.ip.get() == [byte 10, 20, 30, 40] - typedEnr.tcp.isSome - typedEnr.tcp.get == 9000 + typedEnr.tcp.isSome() + typedEnr.tcp.get() == 9000 - typedEnr.udp.isSome - typedEnr.udp.get == 9000 + typedEnr.udp.isSome() + typedEnr.udp.get() == 9000 + test "ENR without address": + let + keys = newKeyPair() + enr = Record.init(100, keys.seckey, none(Address)) + typedEnr = get enr.toTypedRecord() + + check: + typedEnr.secp256k1.isSome() + typedEnr.secp256k1.get() == keys.pubkey.getRawCompressed() + + typedEnr.ip.isNone() + typedEnr.tcp.isNone() + typedEnr.udp.isNone() + typedEnr.ip6.isNone() + typedEnr.tcp6.isNone() + typedEnr.udp6.isNone()