Discv5 Protocol: Add support for banning nodes (#769)

* Add banned nodes to routing table.

* Filter out banned nodes in lookups and cleanup expired bans in refreshLoop.

* Don't respond to messages from banned nodes.

* Prevent sending messages to banned nodes.
This commit is contained in:
bhartnett 2025-01-30 19:28:10 +08:00 committed by GitHub
parent e589cc0288
commit c640d3c444
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 190 additions and 18 deletions

View File

@ -124,6 +124,9 @@ const
defaultResponseTimeout* = 4.seconds ## timeout for the response of a request-response defaultResponseTimeout* = 4.seconds ## timeout for the response of a request-response
## call ## call
## Ban durations for banned nodes in the routing table
NodeBanDurationInvalidResponse = 15.minutes
type type
OptAddress* = object OptAddress* = object
ip*: Opt[IpAddress] ip*: Opt[IpAddress]
@ -142,6 +145,7 @@ type
bindAddress: OptAddress ## UDP binding address bindAddress: OptAddress ## UDP binding address
pendingRequests: Table[AESGCMNonce, PendingRequest] pendingRequests: Table[AESGCMNonce, PendingRequest]
routingTable*: RoutingTable routingTable*: RoutingTable
banNodes: bool
codec*: Codec codec*: Codec
awaitedMessages: Table[(NodeId, RequestId), Future[Opt[Message]]] awaitedMessages: Table[(NodeId, RequestId), Future[Opt[Message]]]
refreshLoop: Future[void] refreshLoop: Future[void]
@ -157,6 +161,7 @@ type
responseTimeout: Duration responseTimeout: Duration
rng*: ref HmacDrbgContext rng*: ref HmacDrbgContext
PendingRequest = object PendingRequest = object
node: Node node: Node
message: seq[byte] message: seq[byte]
@ -192,10 +197,13 @@ proc addNode*(d: Protocol, node: Node): bool =
## ##
## Returns true only when `Node` was added as a new entry to a bucket in the ## Returns true only when `Node` was added as a new entry to a bucket in the
## routing table. ## routing table.
if d.routingTable.addNode(node) == Added: let r = d.routingTable.addNode(node)
if r == Added:
return true return true
else:
return false if r == Banned:
debug "Banned node not added to routing table", nodeId = node.id
return false
proc addNode*(d: Protocol, r: Record): bool = proc addNode*(d: Protocol, r: Record): bool =
## Add `Node` from a `Record` to discovery routing table. ## Add `Node` from a `Record` to discovery routing table.
@ -429,6 +437,30 @@ proc sendWhoareyou(d: Protocol, toId: NodeId, a: Address,
else: else:
debug "Node with this id already has ongoing handshake, ignoring packet" debug "Node with this id already has ongoing handshake, ignoring packet"
proc replaceNode(d: Protocol, n: Node) =
  ## Pass `n` on to the routing table's `replaceNode`, except when its record
  ## is listed in `d.bootstrapRecords`; in that case the routing table is left
  ## untouched and only a debug line is emitted.
  if n.record in d.bootstrapRecords:
    # For now we never remove bootstrap nodes. It might make sense to actually
    # do so and to retry them only in case we drop to a really low amount of
    # peers in the routing table.
    debug "Message request to bootstrap node failed", enr = toURI(n.record)
    return

  d.routingTable.replaceNode(n)
proc banNode*(d: Protocol, n: Node, banPeriod: chronos.Duration) =
  ## Ban `n` in the routing table for `banPeriod`.
  ##
  ## Nodes whose record is in `d.bootstrapRecords` are never banned or
  ## replaced; only a debug line is logged for them. When banning is switched
  ## off (`d.banNodes` is false) the node is merely handed to the routing
  ## table's `replaceNode`.
  if n.record in d.bootstrapRecords:
    # For now we never remove bootstrap nodes. It might make sense to actually
    # do so and to retry them only in case we drop to a really low amount of
    # peers in the routing table.
    debug "Message request to bootstrap node failed", enr = toURI(n.record)
    return

  if not d.banNodes:
    d.routingTable.replaceNode(n)
  else:
    # banNode also replaces the node
    d.routingTable.banNode(n.id, banPeriod)
proc isBanned*(d: Protocol, nodeId: NodeId): bool =
  ## Returns true only when banning is enabled on this protocol instance
  ## (`d.banNodes`) and the routing table reports `nodeId` as banned. The
  ## routing table is not consulted at all when banning is disabled.
  if d.banNodes:
    d.routingTable.isBanned(nodeId)
  else:
    false
proc receive*(d: Protocol, a: Address, packet: openArray[byte]) = proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
discv5_network_bytes.inc(packet.len.int64, labelValues = [$Direction.In]) discv5_network_bytes.inc(packet.len.int64, labelValues = [$Direction.In])
@ -437,6 +469,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
let packet = decoded[] let packet = decoded[]
case packet.flag case packet.flag
of OrdinaryMessage: of OrdinaryMessage:
if d.isBanned(packet.srcId):
trace "Ignoring received OrdinaryMessage from banned node", nodeId = packet.srcId
return
if packet.messageOpt.isSome(): if packet.messageOpt.isSome():
let message = packet.messageOpt.get() let message = packet.messageOpt.get()
trace "Received message packet", srcId = packet.srcId, address = a, trace "Received message packet", srcId = packet.srcId, address = a,
@ -464,6 +500,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
else: else:
debug "Timed out or unrequested whoareyou packet", address = a debug "Timed out or unrequested whoareyou packet", address = a
of HandshakeMessage: of HandshakeMessage:
if d.isBanned(packet.srcIdHs):
trace "Ignoring received HandshakeMessage from banned node", nodeId = packet.srcIdHs
return
trace "Received handshake message packet", srcId = packet.srcIdHs, trace "Received handshake message packet", srcId = packet.srcIdHs,
address = a, kind = packet.message.kind address = a, kind = packet.message.kind
d.handleMessage(packet.srcIdHs, a, packet.message, packet.node) d.handleMessage(packet.srcIdHs, a, packet.message, packet.node)
@ -494,14 +534,7 @@ proc processClient(transp: DatagramTransport, raddr: TransportAddress):
proto.receive(Address(ip: raddr.toIpAddress(), port: raddr.port), buf) proto.receive(Address(ip: raddr.toIpAddress(), port: raddr.port), buf)
proc replaceNode(d: Protocol, n: Node) =
  ## Pass `n` on to the routing table's `replaceNode`, except when its record
  ## is listed in `d.bootstrapRecords`; in that case the routing table is left
  ## untouched and only a debug line is emitted.
  if n.record in d.bootstrapRecords:
    # For now we never remove bootstrap nodes. It might make sense to actually
    # do so and to retry them only in case we drop to a really low amount of
    # peers in the routing table.
    debug "Message request to bootstrap node failed", enr = toURI(n.record)
    return

  d.routingTable.replaceNode(n)
# TODO: This could be improved to do the clean-up immediately in case a non # TODO: This could be improved to do the clean-up immediately in case a non
# whoareyou response does arrive, but we would need to store the AuthTag # whoareyou response does arrive, but we would need to store the AuthTag
@ -546,9 +579,11 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
break break
return ok(res) return ok(res)
else: else:
d.banNode(fromNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to find node message") return err("Invalid response to find node message")
else: else:
d.replaceNode(fromNode)
discovery_message_requests_outgoing.inc(labelValues = ["no_response"]) discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
return err("Nodes message not received in time") return err("Nodes message not received in time")
@ -574,6 +609,10 @@ proc ping*(d: Protocol, toNode: Node):
## Send a discovery ping message. ## Send a discovery ping message.
## ##
## Returns the received pong message or an error. ## Returns the received pong message or an error.
if d.isBanned(toNode.id):
return err("toNode is banned")
let reqId = d.sendMessage(toNode, let reqId = d.sendMessage(toNode,
PingMessage(enrSeq: d.localNode.record.seqNum)) PingMessage(enrSeq: d.localNode.record.seqNum))
let resp = await d.waitMessage(toNode, reqId) let resp = await d.waitMessage(toNode, reqId)
@ -583,7 +622,7 @@ proc ping*(d: Protocol, toNode: Node):
d.routingTable.setJustSeen(toNode) d.routingTable.setJustSeen(toNode)
return ok(resp.get().pong) return ok(resp.get().pong)
else: else:
d.replaceNode(toNode) d.banNode(toNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to ping message") return err("Invalid response to ping message")
else: else:
@ -597,15 +636,18 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]):
## ##
## Returns the received nodes or an error. ## Returns the received nodes or an error.
## Received ENRs are already validated and converted to `Node`. ## Received ENRs are already validated and converted to `Node`.
if d.isBanned(toNode.id):
return err("toNode is banned")
let reqId = d.sendMessage(toNode, FindNodeMessage(distances: distances)) let reqId = d.sendMessage(toNode, FindNodeMessage(distances: distances))
let nodes = await d.waitNodes(toNode, reqId) let nodes = await d.waitNodes(toNode, reqId)
if nodes.isOk: if nodes.isOk:
let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit, distances) let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit, distances)
d.routingTable.setJustSeen(toNode) d.routingTable.setJustSeen(toNode)
return ok(res) return ok(res.filterIt(not d.isBanned(it.id)))
else: else:
d.replaceNode(toNode)
return err(nodes.error) return err(nodes.error)
proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
@ -613,6 +655,10 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
## Send a discovery talkreq message. ## Send a discovery talkreq message.
## ##
## Returns the received talkresp message or an error. ## Returns the received talkresp message or an error.
if d.isBanned(toNode.id):
return err("toNode is banned")
let reqId = d.sendMessage(toNode, let reqId = d.sendMessage(toNode,
TalkReqMessage(protocol: protocol, request: request)) TalkReqMessage(protocol: protocol, request: request))
let resp = await d.waitMessage(toNode, reqId) let resp = await d.waitMessage(toNode, reqId)
@ -622,7 +668,7 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
d.routingTable.setJustSeen(toNode) d.routingTable.setJustSeen(toNode)
return ok(resp.get().talkResp.response) return ok(resp.get().talkResp.response)
else: else:
d.replaceNode(toNode) d.banNode(toNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to talk request message") return err("Invalid response to talk request message")
else: else:
@ -797,6 +843,12 @@ proc resolve*(d: Protocol, id: NodeId): Future[Opt[Node]] {.async: (raises: [Can
if id == d.localNode.id: if id == d.localNode.id:
return Opt.some(d.localNode) return Opt.some(d.localNode)
# No point in trying to resolve a banned node because it won't exist in the
# routing table and it will be filtered out of any responses in the lookup call
if d.isBanned(id):
debug "Not resolving banned node", nodeId = id
return Opt.none(Node)
let node = d.getNode(id) let node = d.getNode(id)
if node.isSome(): if node.isSome():
let request = await d.findNode(node.get(), @[0'u16]) let request = await d.findNode(node.get(), @[0'u16])
@ -882,6 +934,9 @@ proc refreshLoop(d: Protocol) {.async: (raises: []).} =
trace "Discovered nodes in random target query", nodes = randomQuery.len trace "Discovered nodes in random target query", nodes = randomQuery.len
debug "Total nodes in discv5 routing table", total = d.routingTable.len() debug "Total nodes in discv5 routing table", total = d.routingTable.len()
# Remove the expired bans from routing table to limit memory usage
d.routingTable.cleanupExpiredBans()
await sleepAsync(refreshInterval) await sleepAsync(refreshInterval)
except CancelledError: except CancelledError:
trace "refreshLoop canceled" trace "refreshLoop canceled"
@ -985,6 +1040,7 @@ proc newProtocol*(
bindPort: Port, bindPort: Port,
bindIp = IPv4_any(), bindIp = IPv4_any(),
enrAutoUpdate = false, enrAutoUpdate = false,
banNodes = false,
config = defaultDiscoveryConfig, config = defaultDiscoveryConfig,
rng = newRng()): rng = newRng()):
Protocol = Protocol =
@ -1034,6 +1090,7 @@ proc newProtocol*(
enrAutoUpdate: enrAutoUpdate, enrAutoUpdate: enrAutoUpdate,
routingTable: RoutingTable.init( routingTable: RoutingTable.init(
node, config.bitsPerHop, config.tableIpLimits, rng), node, config.bitsPerHop, config.tableIpLimits, rng),
banNodes: banNodes,
handshakeTimeout: config.handshakeTimeout, handshakeTimeout: config.handshakeTimeout,
responseTimeout: config.responseTimeout, responseTimeout: config.responseTimeout,
rng: rng) rng: rng)

View File

@ -195,7 +195,7 @@ func ipLimitDec(r: var RoutingTable, b: KBucket, n: Node) =
r.ipLimits.dec(ip) r.ipLimits.dec(ip)
func getNode*(r: RoutingTable, id: NodeId): Opt[Node] func getNode*(r: RoutingTable, id: NodeId): Opt[Node]
proc replaceNode*(r: var RoutingTable, n: Node) proc replaceNode*(r: var RoutingTable, n: Node) {.gcsafe.}
proc banNode*(r: var RoutingTable, nodeId: NodeId, period: chronos.Duration) = proc banNode*(r: var RoutingTable, nodeId: NodeId, period: chronos.Duration) =
## Ban a node from the routing table for the given period. The node is removed ## Ban a node from the routing table for the given period. The node is removed

View File

@ -22,7 +22,8 @@ proc initDiscoveryNode*(
address: Address, address: Address,
bootstrapRecords: openArray[Record] = [], bootstrapRecords: openArray[Record] = [],
localEnrFields: openArray[(string, seq[byte])] = [], localEnrFields: openArray[(string, seq[byte])] = [],
previousRecord = Opt.none(enr.Record)): previousRecord = Opt.none(enr.Record),
banNodes = false):
discv5_protocol.Protocol = discv5_protocol.Protocol =
# set bucketIpLimit to allow bucket split # set bucketIpLimit to allow bucket split
let config = DiscoveryConfig.init(1000, 24, 5) let config = DiscoveryConfig.init(1000, 24, 5)
@ -36,7 +37,8 @@ proc initDiscoveryNode*(
localEnrFields = localEnrFields, localEnrFields = localEnrFields,
previousRecord = previousRecord, previousRecord = previousRecord,
config = config, config = config,
rng = rng) rng = rng,
banNodes = banNodes)
protocol.open() protocol.open()

View File

@ -926,3 +926,116 @@ suite "Discovery v5 Tests":
await node1.closeWait() await node1.closeWait()
await node2.closeWait() await node2.closeWait()
asyncTest "Banned nodes are removed and cannot be added":
  # Exercises the add -> ban -> re-add sequence on a single protocol instance
  # created with banning enabled.
  let
    protocol = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20302), banNodes = true)
    target = generateNode(PrivateKey.random(rng[]))
    targetId = target.id

  # A fresh node is accepted and can be looked up afterwards.
  check protocol.addNode(target) == true
  check protocol.getNode(targetId).isSome()

  # Banning the node should remove it from the routing table.
  protocol.banNode(target, 1.minutes)
  check protocol.getNode(targetId).isNone()

  # While the ban is active the node cannot be added again.
  check protocol.addNode(target) == false
  check protocol.getNode(targetId).isNone()

  await protocol.closeWait()
asyncTest "FindNode filters out banned nodes":
  # `mainNode` answers FindNode requests; `testNode` issues them and is
  # bootstrapped with the record of `mainNode`.
  let
    mainNode = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20301), banNodes = true)
    testNode = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20302),
      @[mainNode.localNode.record], banNodes = true)

  # Generate 100 random nodes and add to our main node's routing table
  for i in 1 .. 100:
    let randomNode = generateNode(PrivateKey.random(rng[]))
    discard mainNode.addSeenNode(randomNode)

  let
    localId = mainNode.localNode.id
    closest = mainNode.neighbours(localId)[0]
    closestDistance = logDistance(closest.id, localId)

  # Before the ban: the closest node is returned.
  let beforeBan =
    await testNode.findNode(mainNode.localNode, @[closestDistance])
  check beforeBan.isOk
  check closest in beforeBan.get()

  # Ban the closest node on the responding side.
  mainNode.banNode(closest, 1.minutes)

  # After the ban: the banned node is not returned.
  let afterBan =
    await testNode.findNode(mainNode.localNode, @[closestDistance])
  check afterBan.isOk
  check closest notin afterBan.get()

  await mainNode.closeWait()
  await testNode.closeWait()
asyncTest "Cannot send messages to banned nodes":
  # The sender has banned the receiver, so every outgoing request is expected
  # to be rejected locally with the "toNode is banned" error.
  let
    node1 = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20302), banNodes = true)
    node2 = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20301), banNodes = true)
    banned = node2.localNode

  # Ban node2 in node1's routing table.
  node1.banNode(banned, 1.minutes)

  # ping is refused
  let pingRes = await node1.ping(banned)
  check pingRes.isErr()
  check pingRes.error() == "toNode is banned"

  # findNode is refused
  let findRes = await node1.findNode(banned, @[0'u16])
  check findRes.isErr()
  check findRes.error() == "toNode is banned"

  # resolve yields nothing
  let resolved = await node1.resolve(banned.id)
  check resolved.isNone()

  await node2.closeWait()
  await node1.closeWait()
asyncTest "Ignore messages from banned nodes":
  # The receiver has banned the sender, so the sender's requests are expected
  # to go unanswered and end in the "not received in time" errors.
  let
    node1 = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20302), banNodes = true)
    node2 = initDiscoveryNode(
      rng, PrivateKey.random(rng[]), localAddress(20301), banNodes = true)
    receiver = node2.localNode

  # Ban node1 in node2's routing table.
  node2.banNode(node1.localNode, 1.minutes)

  # ping gets no pong
  let pingRes = await node1.ping(receiver)
  check pingRes.isErr()
  check pingRes.error() == "Pong message not received in time"

  # findNode gets no nodes message
  let findRes = await node1.findNode(receiver, @[0'u16])
  check findRes.isErr()
  check findRes.error() == "Nodes message not received in time"

  # resolve yields nothing
  let resolved = await node1.resolve(receiver.id)
  check resolved.isNone()

  await node2.closeWait()
  await node1.closeWait()