Add predicate filter option for randomNodes (#251)

* Add predicate filter option for randomNodes

* Further ValidIpAddress fixes

* Add gcsafe/noSideEffect and add test case
This commit is contained in:
Kim De Mey 2020-06-11 21:24:52 +02:00 committed by GitHub
parent be9a87848e
commit 225a9ad41c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 23 deletions

View File

@ -1,6 +1,6 @@
import import
sequtils, options, strutils, chronos, chronicles, chronicles/topics_registry, sequtils, options, strutils, chronos, chronicles, chronicles/topics_registry,
stew/byteutils, confutils, stew/byteutils, stew/shims/net, confutils,
eth/keys, eth/trie/db, eth/net/nat, eth/keys, eth/trie/db, eth/net/nat,
eth/p2p/discoveryv5/[protocol, discovery_db, enr, node] eth/p2p/discoveryv5/[protocol, discovery_db, enr, node]
@ -76,7 +76,7 @@ proc parseCmdArg*(T: type Node, p: TaintedString): T =
proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] = proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] =
return @[] return @[]
proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress], proc setupNat(conf: DiscoveryConf): tuple[ip: Option[ValidIpAddress],
tcpPort: Port, tcpPort: Port,
udpPort: Port] {.gcsafe.} = udpPort: Port] {.gcsafe.} =
# defaults # defaults
@ -96,15 +96,16 @@ proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress],
else: else:
if conf.nat.startsWith("extip:") and isIpAddress(conf.nat[6..^1]): if conf.nat.startsWith("extip:") and isIpAddress(conf.nat[6..^1]):
# any required port redirection is assumed to be done by hand # any required port redirection is assumed to be done by hand
result.ip = some(parseIpAddress(conf.nat[6..^1])) result.ip = some(ValidIpAddress.init(conf.nat[6..^1]))
nat = NatNone nat = NatNone
else: else:
error "not a valid NAT mechanism, nor a valid IP address", value = conf.nat error "not a valid NAT mechanism, nor a valid IP address", value = conf.nat
quit(QuitFailure) quit(QuitFailure)
if nat != NatNone: if nat != NatNone:
result.ip = getExternalIP(nat) let extIp = getExternalIP(nat)
if result.ip.isSome: if extIP.isSome:
result.ip = some(ValidIpAddress.init extIp.get)
let extPorts = ({.gcsafe.}: let extPorts = ({.gcsafe.}:
redirectPorts(tcpPort = result.tcpPort, redirectPorts(tcpPort = result.tcpPort,
udpPort = result.udpPort, udpPort = result.udpPort,

View File

@ -217,6 +217,14 @@ proc toTypedRecord*(r: Record): EnrResult[TypedRecord] =
else: else:
err("Record without id field") err("Record without id field")
proc contains*(r: Record, fp: (string, seq[byte])): bool =
# TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the
# `get` call can be improved to make this easier.
let field = r.tryGet(fp[0], seq[byte])
if field.isSome():
if field.get() == fp[1]:
return true
proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]): proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]):
bool = bool =
let publicKey = r.get(PublicKey) let publicKey = r.get(PublicKey)

View File

@ -31,10 +31,7 @@ proc newNode*(r: Record): Result[Node, cstring] =
let tr = ? r.toTypedRecord() let tr = ? r.toTypedRecord()
if tr.ip.isSome() and tr.udp.isSome(): if tr.ip.isSome() and tr.udp.isSome():
let let a = Address(ip: ipv4(tr.ip.get()), port: Port(tr.udp.get()))
ip = ValidIpAddress.init(
IpAddress(family: IpAddressFamily.IPv4, address_v4: tr.ip.get()))
a = Address(ip: ip, port: Port(tr.udp.get()))
ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r, ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r,
address: some(a))) address: some(a)))

View File

@ -145,8 +145,21 @@ proc addNode*(d: Protocol, enr: EnrUri): bool =
proc getNode*(d: Protocol, id: NodeId): Option[Node] = proc getNode*(d: Protocol, id: NodeId): Option[Node] =
d.routingTable.getNode(id) d.routingTable.getNode(id)
proc randomNodes*(d: Protocol, count: int): seq[Node] = proc randomNodes*(d: Protocol, maxAmount: int): seq[Node] =
d.routingTable.randomNodes(count) ## Get a `maxAmount` of random nodes from the local routing table.
d.routingTable.randomNodes(maxAmount)
proc randomNodes*(d: Protocol, maxAmount: int,
pred: proc(x: Node): bool {.gcsafe, noSideEffect.}): seq[Node] =
## Get a `maxAmount` of random nodes from the local routing table with the
## `pred` predicate function applied as filter on the nodes selected.
d.routingTable.randomNodes(maxAmount, pred)
proc randomNodes*(d: Protocol, maxAmount: int,
enrField: (string, seq[byte])): seq[Node] =
## Get a `maxAmount` of random nodes from the local routing table. The
## the nodes selected are filtered by provided `enrField`.
d.randomNodes(maxAmount, proc(x: Node): bool = x.record.contains(enrField))
proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] = proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
d.routingTable.neighbours(id, k) d.routingTable.neighbours(id, k)
@ -598,6 +611,8 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]]
proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]] proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]]
{.async, raises:[Exception, Defect].} = {.async, raises:[Exception, Defect].} =
## Perform a lookup for a random target, return the closest n nodes to the
## target. Maximum value for n is `BUCKET_SIZE`.
var id: NodeId var id: NodeId
if randomBytes(addr id, sizeof(id)) != sizeof(id): if randomBytes(addr id, sizeof(id)) != sizeof(id):
return err("Could not randomize bytes") return err("Could not randomize bytes")

View File

@ -248,25 +248,28 @@ proc nodeToRevalidate*(r: RoutingTable): Node =
if b.len > 0: if b.len > 0:
return b.nodes[^1] return b.nodes[^1]
proc randomNodes*(r: RoutingTable, count: int): seq[Node] = proc randomNodes*(r: RoutingTable, maxAmount: int,
var count = count pred: proc(x: Node): bool {.gcsafe, noSideEffect.} = nil): seq[Node] =
var maxAmount = maxAmount
let sz = r.len let sz = r.len
if count > sz: if maxAmount > sz:
debug "Looking for peers", requested = count, present = sz debug "Less peers in routing table than maximum requested",
count = sz requested = maxAmount, present = sz
maxAmount = sz
result = newSeqOfCap[Node](count) result = newSeqOfCap[Node](maxAmount)
var seen = initHashSet[Node]() var seen = initHashSet[Node]()
# This is a rather inneficient way of randomizing nodes from all buckets, but even if we # This is a rather inneficient way of randomizing nodes from all buckets, but even if we
# iterate over all nodes in the routing table, the time it takes would still be # iterate over all nodes in the routing table, the time it takes would still be
# insignificant compared to the time it takes for the network roundtrips when connecting # insignificant compared to the time it takes for the network roundtrips when connecting
# to nodes. # to nodes.
while len(seen) < count: while len(seen) < maxAmount:
# TODO: Is it important to get a better random source for these sample calls? # TODO: Is it important to get a better random source for these sample calls?
let bucket = sample(r.buckets) let bucket = sample(r.buckets)
if bucket.nodes.len != 0: if bucket.nodes.len != 0:
let node = sample(bucket.nodes) let node = sample(bucket.nodes)
if node notin seen: if node notin seen:
result.add(node)
seen.incl(node) seen.incl(node)
if pred.isNil() or node.pred:
result.add(node)

View File

@ -9,13 +9,15 @@ proc localAddress*(port: int): Address =
Address(ip: ValidIpAddress.init("127.0.0.1"), port: Port(port)) Address(ip: ValidIpAddress.init("127.0.0.1"), port: Port(port))
proc initDiscoveryNode*(privKey: PrivateKey, address: Address, proc initDiscoveryNode*(privKey: PrivateKey, address: Address,
bootstrapRecords: openarray[Record] = []): bootstrapRecords: openarray[Record] = [],
localEnrFields: openarray[FieldPair] = []):
discv5_protocol.Protocol = discv5_protocol.Protocol =
var db = DiscoveryDB.init(newMemoryDB()) var db = DiscoveryDB.init(newMemoryDB())
result = newProtocol(privKey, db, result = newProtocol(privKey, db,
some(address.ip), some(address.ip),
address.port, address.port, address.port, address.port,
bootstrapRecords = bootstrapRecords) bootstrapRecords = bootstrapRecords,
localEnrFields = localEnrFields)
result.open() result.open()
@ -35,10 +37,11 @@ proc randomPacket(tag: PacketTag): seq[byte] =
result.add(rlp.encode(authTag)) result.add(rlp.encode(authTag))
result.add(msg) result.add(msg)
proc generateNode(privKey = PrivateKey.random()[], port: int = 20302): Node = proc generateNode(privKey = PrivateKey.random()[], port: int = 20302,
localEnrFields: openarray[FieldPair] = []): Node =
let port = Port(port) let port = Port(port)
let enr = enr.Record.init(1, privKey, some(ValidIpAddress.init("127.0.0.1")), let enr = enr.Record.init(1, privKey, some(ValidIpAddress.init("127.0.0.1")),
port, port).expect("Properly intialized private key") port, port, localEnrFields).expect("Properly intialized private key")
result = newNode(enr).expect("Properly initialized node") result = newNode(enr).expect("Properly initialized node")
proc nodeAtDistance(n: Node, d: uint32): Node = proc nodeAtDistance(n: Node, d: uint32): Node =
@ -387,3 +390,25 @@ suite "Discovery v5 Tests":
await mainNode.closeWait() await mainNode.closeWait()
await lookupNode.closeWait() await lookupNode.closeWait()
asyncTest "Random nodes with enr field filter":
let
lookupNode = initDiscoveryNode(PrivateKey.random()[], localAddress(20301))
targetFieldPair = toFieldPair("test", @[byte 1,2,3,4])
targetNode = generateNode(localEnrFields = [targetFieldPair])
otherFieldPair = toFieldPair("test", @[byte 1,2,3,4,5])
otherNode = generateNode(localEnrFields = [otherFieldPair])
anotherNode = generateNode()
check:
lookupNode.addNode(targetNode)
lookupNode.addNode(otherNode)
lookupNode.addNode(anotherNode)
let discovered = lookupNode.randomNodes(10)
check discovered.len == 3
let discoveredFiltered = lookupNode.randomNodes(10,
("test", @[byte 1,2,3,4]))
check discoveredFiltered.len == 1 and discoveredFiltered.contains(targetNode)
await lookupNode.closeWait()