mirror of https://github.com/status-im/nim-eth.git
Add predicate filter option for randomNodes (#251)
* Add predicate filter option for randomNodes * Further ValidIpAddress fixes * Add gcsafe/noSideEffect and add test case
This commit is contained in:
parent
be9a87848e
commit
225a9ad41c
|
@ -1,6 +1,6 @@
|
||||||
import
|
import
|
||||||
sequtils, options, strutils, chronos, chronicles, chronicles/topics_registry,
|
sequtils, options, strutils, chronos, chronicles, chronicles/topics_registry,
|
||||||
stew/byteutils, confutils,
|
stew/byteutils, stew/shims/net, confutils,
|
||||||
eth/keys, eth/trie/db, eth/net/nat,
|
eth/keys, eth/trie/db, eth/net/nat,
|
||||||
eth/p2p/discoveryv5/[protocol, discovery_db, enr, node]
|
eth/p2p/discoveryv5/[protocol, discovery_db, enr, node]
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ proc parseCmdArg*(T: type Node, p: TaintedString): T =
|
||||||
proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] =
|
proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] =
|
||||||
return @[]
|
return @[]
|
||||||
|
|
||||||
proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress],
|
proc setupNat(conf: DiscoveryConf): tuple[ip: Option[ValidIpAddress],
|
||||||
tcpPort: Port,
|
tcpPort: Port,
|
||||||
udpPort: Port] {.gcsafe.} =
|
udpPort: Port] {.gcsafe.} =
|
||||||
# defaults
|
# defaults
|
||||||
|
@ -96,15 +96,16 @@ proc setupNat(conf: DiscoveryConf): tuple[ip: Option[IpAddress],
|
||||||
else:
|
else:
|
||||||
if conf.nat.startsWith("extip:") and isIpAddress(conf.nat[6..^1]):
|
if conf.nat.startsWith("extip:") and isIpAddress(conf.nat[6..^1]):
|
||||||
# any required port redirection is assumed to be done by hand
|
# any required port redirection is assumed to be done by hand
|
||||||
result.ip = some(parseIpAddress(conf.nat[6..^1]))
|
result.ip = some(ValidIpAddress.init(conf.nat[6..^1]))
|
||||||
nat = NatNone
|
nat = NatNone
|
||||||
else:
|
else:
|
||||||
error "not a valid NAT mechanism, nor a valid IP address", value = conf.nat
|
error "not a valid NAT mechanism, nor a valid IP address", value = conf.nat
|
||||||
quit(QuitFailure)
|
quit(QuitFailure)
|
||||||
|
|
||||||
if nat != NatNone:
|
if nat != NatNone:
|
||||||
result.ip = getExternalIP(nat)
|
let extIp = getExternalIP(nat)
|
||||||
if result.ip.isSome:
|
if extIP.isSome:
|
||||||
|
result.ip = some(ValidIpAddress.init extIp.get)
|
||||||
let extPorts = ({.gcsafe.}:
|
let extPorts = ({.gcsafe.}:
|
||||||
redirectPorts(tcpPort = result.tcpPort,
|
redirectPorts(tcpPort = result.tcpPort,
|
||||||
udpPort = result.udpPort,
|
udpPort = result.udpPort,
|
||||||
|
|
|
@ -217,6 +217,14 @@ proc toTypedRecord*(r: Record): EnrResult[TypedRecord] =
|
||||||
else:
|
else:
|
||||||
err("Record without id field")
|
err("Record without id field")
|
||||||
|
|
||||||
|
proc contains*(r: Record, fp: (string, seq[byte])): bool =
|
||||||
|
# TODO: use FieldPair for this, but that is a bit cumbersome. Perhaps the
|
||||||
|
# `get` call can be improved to make this easier.
|
||||||
|
let field = r.tryGet(fp[0], seq[byte])
|
||||||
|
if field.isSome():
|
||||||
|
if field.get() == fp[1]:
|
||||||
|
return true
|
||||||
|
|
||||||
proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]):
|
proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]):
|
||||||
bool =
|
bool =
|
||||||
let publicKey = r.get(PublicKey)
|
let publicKey = r.get(PublicKey)
|
||||||
|
|
|
@ -31,10 +31,7 @@ proc newNode*(r: Record): Result[Node, cstring] =
|
||||||
|
|
||||||
let tr = ? r.toTypedRecord()
|
let tr = ? r.toTypedRecord()
|
||||||
if tr.ip.isSome() and tr.udp.isSome():
|
if tr.ip.isSome() and tr.udp.isSome():
|
||||||
let
|
let a = Address(ip: ipv4(tr.ip.get()), port: Port(tr.udp.get()))
|
||||||
ip = ValidIpAddress.init(
|
|
||||||
IpAddress(family: IpAddressFamily.IPv4, address_v4: tr.ip.get()))
|
|
||||||
a = Address(ip: ip, port: Port(tr.udp.get()))
|
|
||||||
|
|
||||||
ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r,
|
ok(Node(id: pk.get().toNodeId(), pubkey: pk.get() , record: r,
|
||||||
address: some(a)))
|
address: some(a)))
|
||||||
|
|
|
@ -145,8 +145,21 @@ proc addNode*(d: Protocol, enr: EnrUri): bool =
|
||||||
proc getNode*(d: Protocol, id: NodeId): Option[Node] =
|
proc getNode*(d: Protocol, id: NodeId): Option[Node] =
|
||||||
d.routingTable.getNode(id)
|
d.routingTable.getNode(id)
|
||||||
|
|
||||||
proc randomNodes*(d: Protocol, count: int): seq[Node] =
|
proc randomNodes*(d: Protocol, maxAmount: int): seq[Node] =
|
||||||
d.routingTable.randomNodes(count)
|
## Get a `maxAmount` of random nodes from the local routing table.
|
||||||
|
d.routingTable.randomNodes(maxAmount)
|
||||||
|
|
||||||
|
proc randomNodes*(d: Protocol, maxAmount: int,
|
||||||
|
pred: proc(x: Node): bool {.gcsafe, noSideEffect.}): seq[Node] =
|
||||||
|
## Get a `maxAmount` of random nodes from the local routing table with the
|
||||||
|
## `pred` predicate function applied as filter on the nodes selected.
|
||||||
|
d.routingTable.randomNodes(maxAmount, pred)
|
||||||
|
|
||||||
|
proc randomNodes*(d: Protocol, maxAmount: int,
|
||||||
|
enrField: (string, seq[byte])): seq[Node] =
|
||||||
|
## Get a `maxAmount` of random nodes from the local routing table. The
|
||||||
|
## the nodes selected are filtered by provided `enrField`.
|
||||||
|
d.randomNodes(maxAmount, proc(x: Node): bool = x.record.contains(enrField))
|
||||||
|
|
||||||
proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
|
proc neighbours*(d: Protocol, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
|
||||||
d.routingTable.neighbours(id, k)
|
d.routingTable.neighbours(id, k)
|
||||||
|
@ -598,6 +611,8 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]]
|
||||||
|
|
||||||
proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]]
|
proc lookupRandom*(d: Protocol): Future[DiscResult[seq[Node]]]
|
||||||
{.async, raises:[Exception, Defect].} =
|
{.async, raises:[Exception, Defect].} =
|
||||||
|
## Perform a lookup for a random target, return the closest n nodes to the
|
||||||
|
## target. Maximum value for n is `BUCKET_SIZE`.
|
||||||
var id: NodeId
|
var id: NodeId
|
||||||
if randomBytes(addr id, sizeof(id)) != sizeof(id):
|
if randomBytes(addr id, sizeof(id)) != sizeof(id):
|
||||||
return err("Could not randomize bytes")
|
return err("Could not randomize bytes")
|
||||||
|
|
|
@ -248,25 +248,28 @@ proc nodeToRevalidate*(r: RoutingTable): Node =
|
||||||
if b.len > 0:
|
if b.len > 0:
|
||||||
return b.nodes[^1]
|
return b.nodes[^1]
|
||||||
|
|
||||||
proc randomNodes*(r: RoutingTable, count: int): seq[Node] =
|
proc randomNodes*(r: RoutingTable, maxAmount: int,
|
||||||
var count = count
|
pred: proc(x: Node): bool {.gcsafe, noSideEffect.} = nil): seq[Node] =
|
||||||
|
var maxAmount = maxAmount
|
||||||
let sz = r.len
|
let sz = r.len
|
||||||
if count > sz:
|
if maxAmount > sz:
|
||||||
debug "Looking for peers", requested = count, present = sz
|
debug "Less peers in routing table than maximum requested",
|
||||||
count = sz
|
requested = maxAmount, present = sz
|
||||||
|
maxAmount = sz
|
||||||
|
|
||||||
result = newSeqOfCap[Node](count)
|
result = newSeqOfCap[Node](maxAmount)
|
||||||
var seen = initHashSet[Node]()
|
var seen = initHashSet[Node]()
|
||||||
|
|
||||||
# This is a rather inneficient way of randomizing nodes from all buckets, but even if we
|
# This is a rather inneficient way of randomizing nodes from all buckets, but even if we
|
||||||
# iterate over all nodes in the routing table, the time it takes would still be
|
# iterate over all nodes in the routing table, the time it takes would still be
|
||||||
# insignificant compared to the time it takes for the network roundtrips when connecting
|
# insignificant compared to the time it takes for the network roundtrips when connecting
|
||||||
# to nodes.
|
# to nodes.
|
||||||
while len(seen) < count:
|
while len(seen) < maxAmount:
|
||||||
# TODO: Is it important to get a better random source for these sample calls?
|
# TODO: Is it important to get a better random source for these sample calls?
|
||||||
let bucket = sample(r.buckets)
|
let bucket = sample(r.buckets)
|
||||||
if bucket.nodes.len != 0:
|
if bucket.nodes.len != 0:
|
||||||
let node = sample(bucket.nodes)
|
let node = sample(bucket.nodes)
|
||||||
if node notin seen:
|
if node notin seen:
|
||||||
result.add(node)
|
|
||||||
seen.incl(node)
|
seen.incl(node)
|
||||||
|
if pred.isNil() or node.pred:
|
||||||
|
result.add(node)
|
||||||
|
|
|
@ -9,13 +9,15 @@ proc localAddress*(port: int): Address =
|
||||||
Address(ip: ValidIpAddress.init("127.0.0.1"), port: Port(port))
|
Address(ip: ValidIpAddress.init("127.0.0.1"), port: Port(port))
|
||||||
|
|
||||||
proc initDiscoveryNode*(privKey: PrivateKey, address: Address,
|
proc initDiscoveryNode*(privKey: PrivateKey, address: Address,
|
||||||
bootstrapRecords: openarray[Record] = []):
|
bootstrapRecords: openarray[Record] = [],
|
||||||
|
localEnrFields: openarray[FieldPair] = []):
|
||||||
discv5_protocol.Protocol =
|
discv5_protocol.Protocol =
|
||||||
var db = DiscoveryDB.init(newMemoryDB())
|
var db = DiscoveryDB.init(newMemoryDB())
|
||||||
result = newProtocol(privKey, db,
|
result = newProtocol(privKey, db,
|
||||||
some(address.ip),
|
some(address.ip),
|
||||||
address.port, address.port,
|
address.port, address.port,
|
||||||
bootstrapRecords = bootstrapRecords)
|
bootstrapRecords = bootstrapRecords,
|
||||||
|
localEnrFields = localEnrFields)
|
||||||
|
|
||||||
result.open()
|
result.open()
|
||||||
|
|
||||||
|
@ -35,10 +37,11 @@ proc randomPacket(tag: PacketTag): seq[byte] =
|
||||||
result.add(rlp.encode(authTag))
|
result.add(rlp.encode(authTag))
|
||||||
result.add(msg)
|
result.add(msg)
|
||||||
|
|
||||||
proc generateNode(privKey = PrivateKey.random()[], port: int = 20302): Node =
|
proc generateNode(privKey = PrivateKey.random()[], port: int = 20302,
|
||||||
|
localEnrFields: openarray[FieldPair] = []): Node =
|
||||||
let port = Port(port)
|
let port = Port(port)
|
||||||
let enr = enr.Record.init(1, privKey, some(ValidIpAddress.init("127.0.0.1")),
|
let enr = enr.Record.init(1, privKey, some(ValidIpAddress.init("127.0.0.1")),
|
||||||
port, port).expect("Properly intialized private key")
|
port, port, localEnrFields).expect("Properly intialized private key")
|
||||||
result = newNode(enr).expect("Properly initialized node")
|
result = newNode(enr).expect("Properly initialized node")
|
||||||
|
|
||||||
proc nodeAtDistance(n: Node, d: uint32): Node =
|
proc nodeAtDistance(n: Node, d: uint32): Node =
|
||||||
|
@ -387,3 +390,25 @@ suite "Discovery v5 Tests":
|
||||||
|
|
||||||
await mainNode.closeWait()
|
await mainNode.closeWait()
|
||||||
await lookupNode.closeWait()
|
await lookupNode.closeWait()
|
||||||
|
|
||||||
|
asyncTest "Random nodes with enr field filter":
|
||||||
|
let
|
||||||
|
lookupNode = initDiscoveryNode(PrivateKey.random()[], localAddress(20301))
|
||||||
|
targetFieldPair = toFieldPair("test", @[byte 1,2,3,4])
|
||||||
|
targetNode = generateNode(localEnrFields = [targetFieldPair])
|
||||||
|
otherFieldPair = toFieldPair("test", @[byte 1,2,3,4,5])
|
||||||
|
otherNode = generateNode(localEnrFields = [otherFieldPair])
|
||||||
|
anotherNode = generateNode()
|
||||||
|
|
||||||
|
check:
|
||||||
|
lookupNode.addNode(targetNode)
|
||||||
|
lookupNode.addNode(otherNode)
|
||||||
|
lookupNode.addNode(anotherNode)
|
||||||
|
|
||||||
|
let discovered = lookupNode.randomNodes(10)
|
||||||
|
check discovered.len == 3
|
||||||
|
let discoveredFiltered = lookupNode.randomNodes(10,
|
||||||
|
("test", @[byte 1,2,3,4]))
|
||||||
|
check discoveredFiltered.len == 1 and discoveredFiltered.contains(targetNode)
|
||||||
|
|
||||||
|
await lookupNode.closeWait()
|
||||||
|
|
Loading…
Reference in New Issue