mirror of https://github.com/status-im/nim-eth.git
Add predicate filter option for randomNodes (#251)
* Add predicate filter option for randomNodes * Further ValidIpAddress fixes * Add gcsafe/noSideEffect and add test case
This commit is contained in:
parent
be9a87848e
commit
225a9ad41c
|
@ -1,6 +1,6 @@
|
|||
import
|
||||
sequtils, options, strutils, chronos, chronicles, chronicles/topics_registry,
|
||||
stew/byteutils, confutils,
|
||||
stew/byteutils, stew/shims/net, confutils,
|
||||
eth/keys, eth/trie/db, eth/net/nat,
|
||||
eth/p2p/discoveryv5/[protocol, discovery_db, enr, node]
|
||||
|
||||
|
@ -76,7 +76,7 @@ proc parseCmdArg*(T: type Node, p: TaintedString): T =
|
|||
proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] =
|
||||
return @[]
|
||||
|
||||
proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress],
|
||||
proc setupNat(conf: DiscoveryConf): tuple[ip: Option[ValidIpAddress],
|
||||
tcpPort: Port,
|
||||
udpPort: Port] {.gcsafe.} =
|
||||
# defaults
|
||||
|
@ -96,15 +96,16 @@ proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress],
|
|||
else:
|
||||
if conf.nat.startsWith("extip:") and isIpAddress(conf.nat[6..^1]):
|
||||
# any required port redirection is assumed to be done by hand
|
||||
result.ip = some(parseIpAddress(conf.nat[6..^1]))
|
||||
result.ip = some(ValidIpAddress.init(conf.nat[6..^1]))
|
||||
nat = NatNone
|
||||
else:
|
||||
error "not a valid NAT mechanism, nor a valid IP address", value = conf.nat
|
||||
quit(QuitFailure)
|
||||
|
||||
if nat != NatNone:
|
||||
result.ip = getExternalIP(nat)
|
||||
if result.ip.isSome:
|
||||
let extIp = getExternalIP(nat)
|
||||
if extIP.isSome:
|
||||
result.ip = some(ValidIpAddress.init extIp.get)
|
||||
let extPorts = ({.gcsafe.}:
|
||||
redirectPorts(tcpPort = result.tcpPort,
|
||||
udpPort = result.udpPort,
|
||||
|
|
|
@ -217,6 +217,14 @@ proc toTypedRecord*(r: Record): EnrResult[TypedRecord] =
|
|||
else:
|
||||
err("Record without id field")
|
||||
|
||||
proc contains*(r: Record, fp: (string, seq[byte])): bool =
|
||||
# TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the
|
||||
# `get` call can be improved to make this easier.
|
||||
let field = r.tryGet(fp[0], seq[byte])
|
||||
if field.isSome():
|
||||
if field.get() == fp[1]:
|
||||
return true
|
||||
|
||||
proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]):
|
||||
bool =
|
||||
let publicKey = r.get(PublicKey)
|
||||
|
|
|
@ -31,10 +31,7 @@ proc newNode*(r: Record): Result[Node, cstring] =
|
|||
|
||||
let tr = ? r.toTypedRecord()
|
||||
if tr.ip.isSome() and tr.udp.isSome():
|
||||
let
|
||||
ip = ValidIpAddress.init(
|
||||
IpAddress(family: IpAddressFamily.IPv4, address_v4: tr.ip.get()))
|
||||
a = Address(ip: ip, port: Port(tr.udp.get()))
|
||||
let a = Address(ip: ipv4(tr.ip.get()), port: Port(tr.udp.get()))
|
||||
|
||||
ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r,
|
||||
address: some(a)))
|
||||
|
|
|
@ -145,8 +145,21 @@ proc addNode*(d: Protocol, enr: EnrUri): bool =
|
|||
proc getNode*(d: Protocol, id: NodeId): Option[Node] =
|
||||
d.routingTable.getNode(id)
|
||||
|
||||
proc randomNodes*(d: Protocol, count: int): seq[Node] =
|
||||
d.routingTable.randomNodes(count)
|
||||
proc randomNodes*(d: Protocol, maxAmount: int): seq[Node] =
|
||||
## Get a `maxAmount` of random nodes from the local routing table.
|
||||
d.routingTable.randomNodes(maxAmount)
|
||||
|
||||
proc randomNodes*(d: Protocol, maxAmount: int,
|
||||
pred: proc(x: Node): bool {.gcsafe, noSideEffect.}): seq[Node] =
|
||||
## Get a `maxAmount` of random nodes from the local routing table with the
|
||||
## `pred` predicate function applied as filter on the nodes selected.
|
||||
d.routingTable.randomNodes(maxAmount, pred)
|
||||
|
||||
proc randomNodes*(d: Protocol, maxAmount: int,
|
||||
enrField: (string, seq[byte])): seq[Node] =
|
||||
## Get a `maxAmount` of random nodes from the local routing table. The
|
||||
## the nodes selected are filtered by provided `enrField`.
|
||||
d.randomNodes(maxAmount, proc(x: Node): bool = x.record.contains(enrField))
|
||||
|
||||
proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
|
||||
d.routingTable.neighbours(id, k)
|
||||
|
@ -598,6 +611,8 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]]
|
|||
|
||||
proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]]
|
||||
{.async, raises:[Exception, Defect].} =
|
||||
## Perform a lookup for a random target, return the closest n nodes to the
|
||||
## target. Maximum value for n is `BUCKET_SIZE`.
|
||||
var id: NodeId
|
||||
if randomBytes(addr id, sizeof(id)) != sizeof(id):
|
||||
return err("Could not randomize bytes")
|
||||
|
|
|
@ -248,25 +248,28 @@ proc nodeToRevalidate*(r: RoutingTable): Node =
|
|||
if b.len > 0:
|
||||
return b.nodes[^1]
|
||||
|
||||
proc randomNodes*(r: RoutingTable, count: int): seq[Node] =
|
||||
var count = count
|
||||
proc randomNodes*(r: RoutingTable, maxAmount: int,
|
||||
pred: proc(x: Node): bool {.gcsafe, noSideEffect.} = nil): seq[Node] =
|
||||
var maxAmount = maxAmount
|
||||
let sz = r.len
|
||||
if count > sz:
|
||||
debug "Looking for peers", requested = count, present = sz
|
||||
count = sz
|
||||
if maxAmount > sz:
|
||||
debug "Less peers in routing table than maximum requested",
|
||||
requested = maxAmount, present = sz
|
||||
maxAmount = sz
|
||||
|
||||
result = newSeqOfCap[Node](count)
|
||||
result = newSeqOfCap[Node](maxAmount)
|
||||
var seen = initHashSet[Node]()
|
||||
|
||||
# This is a rather inneficient way of randomizing nodes from all buckets, but even if we
|
||||
# iterate over all nodes in the routing table, the time it takes would still be
|
||||
# insignificant compared to the time it takes for the network roundtrips when connecting
|
||||
# to nodes.
|
||||
while len(seen) < count:
|
||||
while len(seen) < maxAmount:
|
||||
# TODO: Is it important to get a better random source for these sample calls?
|
||||
let bucket = sample(r.buckets)
|
||||
if bucket.nodes.len != 0:
|
||||
let node = sample(bucket.nodes)
|
||||
if node notin seen:
|
||||
result.add(node)
|
||||
seen.incl(node)
|
||||
if pred.isNil() or node.pred:
|
||||
result.add(node)
|
||||
|
|
|
@ -9,13 +9,15 @@ proc localAddress*(port: int): Address =
|
|||
Address(ip: ValidIpAddress.init("127.0.0.1"), port: Port(port))
|
||||
|
||||
proc initDiscoveryNode*(privKey: PrivateKey, address: Address,
|
||||
bootstrapRecords: openarray[Record] = []):
|
||||
bootstrapRecords: openarray[Record] = [],
|
||||
localEnrFields: openarray[FieldPair] = []):
|
||||
discv5_protocol.Protocol =
|
||||
var db = DiscoveryDB.init(newMemoryDB())
|
||||
result = newProtocol(privKey, db,
|
||||
some(address.ip),
|
||||
address.port, address.port,
|
||||
bootstrapRecords = bootstrapRecords)
|
||||
bootstrapRecords = bootstrapRecords,
|
||||
localEnrFields = localEnrFields)
|
||||
|
||||
result.open()
|
||||
|
||||
|
@ -35,10 +37,11 @@ proc randomPacket(tag: PacketTag): seq[byte] =
|
|||
result.add(rlp.encode(authTag))
|
||||
result.add(msg)
|
||||
|
||||
proc generateNode(privKey = PrivateKey.random()[], port: int = 20302): Node =
|
||||
proc generateNode(privKey = PrivateKey.random()[], port: int = 20302,
|
||||
localEnrFields: openarray[FieldPair] = []): Node =
|
||||
let port = Port(port)
|
||||
let enr = enr.Record.init(1, privKey, some(ValidIpAddress.init("127.0.0.1")),
|
||||
port, port).expect("Properly intialized private key")
|
||||
port, port, localEnrFields).expect("Properly intialized private key")
|
||||
result = newNode(enr).expect("Properly initialized node")
|
||||
|
||||
proc nodeAtDistance(n: Node, d: uint32): Node =
|
||||
|
@ -387,3 +390,25 @@ suite "Discovery v5 Tests":
|
|||
|
||||
await mainNode.closeWait()
|
||||
await lookupNode.closeWait()
|
||||
|
||||
asyncTest "Random nodes with enr field filter":
|
||||
let
|
||||
lookupNode = initDiscoveryNode(PrivateKey.random()[], localAddress(20301))
|
||||
targetFieldPair = toFieldPair("test", @[byte 1,2,3,4])
|
||||
targetNode = generateNode(localEnrFields = [targetFieldPair])
|
||||
otherFieldPair = toFieldPair("test", @[byte 1,2,3,4,5])
|
||||
otherNode = generateNode(localEnrFields = [otherFieldPair])
|
||||
anotherNode = generateNode()
|
||||
|
||||
check:
|
||||
lookupNode.addNode(targetNode)
|
||||
lookupNode.addNode(otherNode)
|
||||
lookupNode.addNode(anotherNode)
|
||||
|
||||
let discovered = lookupNode.randomNodes(10)
|
||||
check discovered.len == 3
|
||||
let discoveredFiltered = lookupNode.randomNodes(10,
|
||||
("test", @[byte 1,2,3,4]))
|
||||
check discoveredFiltered.len == 1 and discoveredFiltered.contains(targetNode)
|
||||
|
||||
await lookupNode.closeWait()
|
||||
|
|
Loading…
Reference in New Issue