diff --git a/eth/p2p/discoveryv5/dcli.nim b/eth/p2p/discoveryv5/dcli.nim index b2f086a..d95ad5d 100644 --- a/eth/p2p/discoveryv5/dcli.nim +++ b/eth/p2p/discoveryv5/dcli.nim @@ -1,6 +1,6 @@ import sequtils, options, strutils, chronos, chronicles, chronicles/topics_registry, - stew/byteutils, confutils, + stew/byteutils, stew/shims/net, confutils, eth/keys, eth/trie/db, eth/net/nat, eth/p2p/discoveryv5/[protocol, discovery_db, enr, node] @@ -76,7 +76,7 @@ proc parseCmdArg*(T: type Node, p: TaintedString): T = proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] = return @[] -proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress], +proc setupNat(conf: DiscoveryConf): tuple[ip: Option[ValidIpAddress], tcpPort: Port, udpPort: Port] {.gcsafe.} = # defaults @@ -96,15 +96,16 @@ proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress], else: if conf.nat.startsWith("extip:") and isIpAddress(conf.nat[6..^1]): # any required port redirection is assumed to be done by hand - result.ip = some(parseIpAddress(conf.nat[6..^1])) + result.ip = some(ValidIpAddress.init(conf.nat[6..^1])) nat = NatNone else: error "not a valid NAT mechanism, nor a valid IP address", value = conf.nat quit(QuitFailure) if nat != NatNone: - result.ip = getExternalIP(nat) - if result.ip.isSome: + let extIp = getExternalIP(nat) + if extIP.isSome: + result.ip = some(ValidIpAddress.init extIp.get) let extPorts = ({.gcsafe.}: redirectPorts(tcpPort = result.tcpPort, udpPort = result.udpPort, diff --git a/eth/p2p/discoveryv5/enr.nim b/eth/p2p/discoveryv5/enr.nim index c0676e3..92c6ce8 100644 --- a/eth/p2p/discoveryv5/enr.nim +++ b/eth/p2p/discoveryv5/enr.nim @@ -217,6 +217,14 @@ proc toTypedRecord*(r: Record): EnrResult[TypedRecord] = else: err("Record without id field") +proc contains*(r: Record, fp: (string, seq[byte])): bool = + # TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the + # `get` call can be improved to make this easier. + let field = r.tryGet(fp[0], seq[byte]) + if field.isSome(): + if field.get() == fp[1]: + return true + proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]): bool = let publicKey = r.get(PublicKey) diff --git a/eth/p2p/discoveryv5/node.nim b/eth/p2p/discoveryv5/node.nim index c4a581b..a4c8598 100644 --- a/eth/p2p/discoveryv5/node.nim +++ b/eth/p2p/discoveryv5/node.nim @@ -31,10 +31,7 @@ proc newNode*(r: Record): Result[Node, cstring] = let tr = ? r.toTypedRecord() if tr.ip.isSome() and tr.udp.isSome(): - let - ip = ValidIpAddress.init( - IpAddress(family: IpAddressFamily.IPv4, address_v4: tr.ip.get())) - a = Address(ip: ip, port: Port(tr.udp.get())) + let a = Address(ip: ipv4(tr.ip.get()), port: Port(tr.udp.get())) ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r, address: some(a))) diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 9ae685e..0689e33 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -145,8 +145,21 @@ proc addNode*(d: Protocol, enr: EnrUri): bool = proc getNode*(d: Protocol, id: NodeId): Option[Node] = d.routingTable.getNode(id) -proc randomNodes*(d: Protocol, count: int): seq[Node] = - d.routingTable.randomNodes(count) +proc randomNodes*(d: Protocol, maxAmount: int): seq[Node] = + ## Get a `maxAmount` of random nodes from the local routing table. + d.routingTable.randomNodes(maxAmount) + +proc randomNodes*(d: Protocol, maxAmount: int, + pred: proc(x: Node): bool {.gcsafe, noSideEffect.}): seq[Node] = + ## Get a `maxAmount` of random nodes from the local routing table with the + ## `pred` predicate function applied as filter on the nodes selected. + d.routingTable.randomNodes(maxAmount, pred) + +proc randomNodes*(d: Protocol, maxAmount: int, + enrField: (string, seq[byte])): seq[Node] = + ## Get a `maxAmount` of random nodes from the local routing table. The + ## the nodes selected are filtered by provided `enrField`. + d.randomNodes(maxAmount, proc(x: Node): bool = x.record.contains(enrField)) proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] = d.routingTable.neighbours(id, k) @@ -598,6 +611,8 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]] {.async, raises:[Exception, Defect].} = + ## Perform a lookup for a random target, return the closest n nodes to the + ## target. Maximum value for n is `BUCKET_SIZE`. var id: NodeId if randomBytes(addr id, sizeof(id)) != sizeof(id): return err("Could not randomize bytes") diff --git a/eth/p2p/discoveryv5/routing_table.nim b/eth/p2p/discoveryv5/routing_table.nim index 9976c21..d63a930 100644 --- a/eth/p2p/discoveryv5/routing_table.nim +++ b/eth/p2p/discoveryv5/routing_table.nim @@ -248,25 +248,28 @@ proc nodeToRevalidate*(r: RoutingTable): Node = if b.len > 0: return b.nodes[^1] -proc randomNodes*(r: RoutingTable, count: int): seq[Node] = - var count = count +proc randomNodes*(r: RoutingTable, maxAmount: int, + pred: proc(x: Node): bool {.gcsafe, noSideEffect.} = nil): seq[Node] = + var maxAmount = maxAmount let sz = r.len - if count > sz: - debug "Looking for peers", requested = count, present = sz - count = sz + if maxAmount > sz: + debug "Less peers in routing table than maximum requested", + requested = maxAmount, present = sz + maxAmount = sz - result = newSeqOfCap[Node](count) + result = newSeqOfCap[Node](maxAmount) var seen = initHashSet[Node]() # This is a rather inneficient way of randomizing nodes from all buckets, but even if we # iterate over all nodes in the routing table, the time it takes would still be # insignificant compared to the time it takes for the network roundtrips when connecting # to nodes. - while len(seen) < count: + while len(seen) < maxAmount: # TODO: Is it important to get a better random source for these sample calls? let bucket = sample(r.buckets) if bucket.nodes.len != 0: let node = sample(bucket.nodes) if node notin seen: - result.add(node) seen.incl(node) + if pred.isNil() or node.pred: + result.add(node) diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index 67907f3..d5f6203 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -9,13 +9,15 @@ proc localAddress*(port: int): Address = Address(ip: ValidIpAddress.init("127.0.0.1"), port: Port(port)) proc initDiscoveryNode*(privKey: PrivateKey, address: Address, - bootstrapRecords: openarray[Record] = []): + bootstrapRecords: openarray[Record] = [], + localEnrFields: openarray[FieldPair] = []): discv5_protocol.Protocol = var db = DiscoveryDB.init(newMemoryDB()) result = newProtocol(privKey, db, some(address.ip), address.port, address.port, - bootstrapRecords = bootstrapRecords) + bootstrapRecords = bootstrapRecords, + localEnrFields = localEnrFields) result.open() @@ -35,10 +37,11 @@ proc randomPacket(tag: PacketTag): seq[byte] = result.add(rlp.encode(authTag)) result.add(msg) -proc generateNode(privKey = PrivateKey.random()[], port: int = 20302): Node = +proc generateNode(privKey = PrivateKey.random()[], port: int = 20302, + localEnrFields: openarray[FieldPair] = []): Node = let port = Port(port) let enr = enr.Record.init(1, privKey, some(ValidIpAddress.init("127.0.0.1")), - port, port).expect("Properly intialized private key") + port, port, localEnrFields).expect("Properly intialized private key") result = newNode(enr).expect("Properly initialized node") proc nodeAtDistance(n: Node, d: uint32): Node = @@ -387,3 +390,25 @@ suite "Discovery v5 Tests": await mainNode.closeWait() await lookupNode.closeWait() + + asyncTest "Random nodes with enr field filter": + let + lookupNode = initDiscoveryNode(PrivateKey.random()[], localAddress(20301)) + targetFieldPair = toFieldPair("test", @[byte 1,2,3,4]) + targetNode = generateNode(localEnrFields = [targetFieldPair]) + otherFieldPair = toFieldPair("test", @[byte 1,2,3,4,5]) + otherNode = generateNode(localEnrFields = [otherFieldPair]) + anotherNode = generateNode() + + check: + lookupNode.addNode(targetNode) + lookupNode.addNode(otherNode) + lookupNode.addNode(anotherNode) + + let discovered = lookupNode.randomNodes(10) + check discovered.len == 3 + let discoveredFiltered = lookupNode.randomNodes(10, + ("test", @[byte 1,2,3,4])) + check discoveredFiltered.len == 1 and discoveredFiltered.contains(targetNode) + + await lookupNode.closeWait()