mirror of
https://github.com/status-im/nimbus-eth1.git
synced 2025-02-24 01:38:33 +00:00
Add validation handler to Portal wire protocol (#1055)
- Validation handler allows for network specific validation of the content - Added also pruning of the allowed connections on PortalStream
This commit is contained in:
parent
82aab2a404
commit
ed38ed494f
@ -55,14 +55,14 @@ func getEncodedKeyForContent(
|
||||
return encodeKey(contentKey)
|
||||
|
||||
proc validateHeaderBytes*(
|
||||
bytes: seq[byte], hash: BlockHash): Option[BlockHeader] =
|
||||
bytes: openArray[byte], hash: BlockHash): Option[BlockHeader] =
|
||||
try:
|
||||
var rlp = rlpFromBytes(bytes)
|
||||
|
||||
let blockHeader = rlp.read(BlockHeader)
|
||||
|
||||
if not (blockHeader.blockHash() == hash):
|
||||
# TODO: Header with different hash than expecte, maybe we should punish
|
||||
# TODO: Header with different hash than expected, maybe we should punish
|
||||
# peer which sent us this ?
|
||||
return none(BlockHeader)
|
||||
|
||||
@ -73,7 +73,7 @@ proc validateHeaderBytes*(
|
||||
return none(BlockHeader)
|
||||
|
||||
proc validateBodyBytes*(
|
||||
bytes: seq[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
|
||||
bytes: openArray[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
|
||||
Option[BlockBody] =
|
||||
try:
|
||||
var rlp = rlpFromBytes(bytes)
|
||||
@ -136,7 +136,7 @@ proc getBlockHeader*(
|
||||
# Content is valid we can propagate it to interested peers
|
||||
h.portalProtocol.triggerPoke(
|
||||
headerContent.nodesInterestedInContent,
|
||||
keyEncoded,
|
||||
keyEncoded,
|
||||
headerContent.content
|
||||
)
|
||||
|
||||
@ -182,7 +182,7 @@ proc getBlock*(
|
||||
# body is valid, propagate it to interested peers
|
||||
h.portalProtocol.triggerPoke(
|
||||
bodyContent.nodesInterestedInContent,
|
||||
keyEncoded,
|
||||
keyEncoded,
|
||||
bodyContent.content
|
||||
)
|
||||
|
||||
@ -192,7 +192,23 @@ proc getBlock*(
|
||||
|
||||
return some[Block]((header, blockBody))
|
||||
|
||||
# TODO Add getRecepits call
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
let keyOpt = contentKey.decode()
|
||||
|
||||
if keyOpt.isNone():
|
||||
return false
|
||||
|
||||
let key = keyOpt.get()
|
||||
|
||||
case key.contentType:
|
||||
of blockHeader:
|
||||
validateHeaderBytes(content, key.blockHeaderKey.blockHash).isSome()
|
||||
of blockBody:
|
||||
true
|
||||
# TODO: Need to get the header from the db or the network for this. Or how
|
||||
# to deal with this?
|
||||
of receipts:
|
||||
true
|
||||
|
||||
proc new*(
|
||||
T: type HistoryNetwork,
|
||||
@ -202,7 +218,8 @@ proc new*(
|
||||
bootstrapRecords: openArray[Record] = [],
|
||||
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
|
||||
let portalProtocol = PortalProtocol.new(
|
||||
baseProtocol, historyProtocolId, contentDB, toContentIdHandler,
|
||||
baseProtocol, historyProtocolId, contentDB,
|
||||
toContentIdHandler, validateContent,
|
||||
dataRadius, bootstrapRecords,
|
||||
config = portalConfig)
|
||||
|
||||
|
@ -56,6 +56,9 @@ proc getContent*(n: StateNetwork, key: ContentKey):
|
||||
# domain types.
|
||||
return some(contentResult.content)
|
||||
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
true
|
||||
|
||||
proc new*(
|
||||
T: type StateNetwork,
|
||||
baseProtocol: protocol.Protocol,
|
||||
@ -64,7 +67,8 @@ proc new*(
|
||||
bootstrapRecords: openArray[Record] = [],
|
||||
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
|
||||
let portalProtocol = PortalProtocol.new(
|
||||
baseProtocol, stateProtocolId, contentDB, toContentIdHandler,
|
||||
baseProtocol, stateProtocolId, contentDB,
|
||||
toContentIdHandler, validateContent,
|
||||
dataRadius, bootstrapRecords, stateDistanceCalculator,
|
||||
config = portalConfig)
|
||||
|
||||
|
@ -117,10 +117,14 @@ type
|
||||
ToContentIdHandler* =
|
||||
proc(contentKey: ByteList): Option[ContentId] {.raises: [Defect], gcsafe.}
|
||||
|
||||
ContentValidationHandler* =
|
||||
proc(content: openArray[byte], contentKey: ByteList):
|
||||
bool {.raises: [Defect], gcsafe.}
|
||||
|
||||
PortalProtocolId* = array[2, byte]
|
||||
|
||||
RadiusCache* = LRUCache[NodeId, UInt256]
|
||||
|
||||
|
||||
ContentInfo* = object
|
||||
contentKey*: ByteList
|
||||
content*: seq[byte]
|
||||
@ -142,6 +146,7 @@ type
|
||||
baseProtocol*: protocol.Protocol
|
||||
contentDB*: ContentDB
|
||||
toContentId: ToContentIdHandler
|
||||
validateContent: ContentValidationHandler
|
||||
dataRadius*: UInt256
|
||||
bootstrapRecords*: seq[Record]
|
||||
lastLookup: chronos.Moment
|
||||
@ -173,7 +178,7 @@ type
|
||||
nodesInterestedInContent*: seq[Node]
|
||||
|
||||
proc init*(
|
||||
T: type ContentInfo,
|
||||
T: type ContentInfo,
|
||||
contentKey: ByteList,
|
||||
content: seq[byte]): T =
|
||||
ContentInfo(
|
||||
@ -182,11 +187,11 @@ proc init*(
|
||||
)
|
||||
|
||||
proc init*(
|
||||
T: type ContentLookupResult,
|
||||
content: seq[byte],
|
||||
T: type ContentLookupResult,
|
||||
content: seq[byte],
|
||||
nodesInterestedInContent: seq[Node]): T =
|
||||
ContentLookupResult(
|
||||
content: content,
|
||||
content: content,
|
||||
nodesInterestedInContent: nodesInterestedInContent
|
||||
)
|
||||
|
||||
@ -210,8 +215,8 @@ func neighbours*(p: PortalProtocol, id: NodeId, seenOnly = false): seq[Node] =
|
||||
|
||||
proc inRange(
|
||||
p: PortalProtocol,
|
||||
nodeId: NodeId,
|
||||
nodeRadius: Uint256,
|
||||
nodeId: NodeId,
|
||||
nodeRadius: Uint256,
|
||||
contentId: ContentId): bool =
|
||||
let distance = p.routingTable.distance(nodeId, contentId)
|
||||
distance <= nodeRadius
|
||||
@ -414,6 +419,7 @@ proc new*(T: type PortalProtocol,
|
||||
protocolId: PortalProtocolId,
|
||||
contentDB: ContentDB,
|
||||
toContentId: ToContentIdHandler,
|
||||
validateContent: ContentValidationHandler,
|
||||
dataRadius = UInt256.high(),
|
||||
bootstrapRecords: openArray[Record] = [],
|
||||
distanceCalculator: DistanceCalculator = XorDistanceCalculator,
|
||||
@ -429,6 +435,7 @@ proc new*(T: type PortalProtocol,
|
||||
baseProtocol: baseProtocol,
|
||||
contentDB: contentDB,
|
||||
toContentId: toContentId,
|
||||
validateContent: validateContent,
|
||||
dataRadius: dataRadius,
|
||||
bootstrapRecords: @bootstrapRecords,
|
||||
radiusCache: RadiusCache.init(256),
|
||||
@ -612,7 +619,7 @@ proc findContent*(p: PortalProtocol, dst: Node, contentKey: ByteList):
|
||||
else:
|
||||
return err("Content message returned invalid ENRs")
|
||||
|
||||
proc getContentKeys(o: OfferRequest): ContentKeysList =
|
||||
proc getContentKeys(o: OfferRequest): ContentKeysList =
|
||||
case o.kind
|
||||
of Direct:
|
||||
var contentKeys:ContentKeysList
|
||||
@ -640,7 +647,7 @@ proc offer(p: PortalProtocol, o: OfferRequest):
|
||||
## to many peers, and keeping it all in memory could exhaust node resources.
|
||||
## Main drawback is that content may be deleted from the node database
|
||||
## by the cleanup process before it will be transferred, so this way does not
|
||||
## guarantee content transfer
|
||||
## guarantee content transfer.
|
||||
let contentKeys = getContentKeys(o)
|
||||
|
||||
let acceptMessageResponse = await p.offerImpl(o.dst, contentKeys)
|
||||
@ -716,7 +723,7 @@ proc offer*(p: PortalProtocol, dst: Node, content: seq[ContentInfo]):
|
||||
Future[PortalResult[void]] {.async.} =
|
||||
if len(content) > contentKeysLimit:
|
||||
return err("Cannot offer more than 64 content items")
|
||||
|
||||
|
||||
let contentList = List[ContentInfo, contentKeysLimit].init(content)
|
||||
let req = OfferRequest(dst: dst, kind: Direct, contentList: contentList)
|
||||
let res = await p.offer(req)
|
||||
@ -759,24 +766,26 @@ proc processContent(
|
||||
{.gcsafe, raises: [Defect].} =
|
||||
let p = getUserData[PortalProtocol](stream)
|
||||
|
||||
# TODO: validate content
|
||||
# - check amount of content items according to ContentKeysList
|
||||
# - History Network specific: each content item, if header, check hash:
|
||||
# this part of thevalidation will be specific per network & type and should
|
||||
# be thus be custom per network
|
||||
# TODO:
|
||||
# - Implement a way to discern different content items (e.g. length prefixed)
|
||||
# - Check amount of content items according to ContentKeysList
|
||||
# - The above could also live in `PortalStream`
|
||||
|
||||
# TODO: for now we only consider 1 item being offered
|
||||
if contentKeys.len() == 1:
|
||||
let contentKey = contentKeys[0]
|
||||
let contentIdOpt = p.toContentId(contentKey)
|
||||
if contentIdOpt.isNone():
|
||||
return
|
||||
if p.validateContent(content, contentKey):
|
||||
let contentIdOpt = p.toContentId(contentKey)
|
||||
if contentIdOpt.isNone():
|
||||
return
|
||||
|
||||
let contentId = contentIdOpt.get()
|
||||
# Store content, should we recheck radius?
|
||||
p.contentDB.put(contentId, content)
|
||||
let contentId = contentIdOpt.get()
|
||||
# Store content, should we recheck radius?
|
||||
p.contentDB.put(contentId, content)
|
||||
|
||||
asyncSpawn neighborhoodGossip(p, contentKeys)
|
||||
asyncSpawn neighborhoodGossip(p, contentKeys)
|
||||
else:
|
||||
error "Received invalid content", contentKey
|
||||
|
||||
proc lookupWorker(
|
||||
p: PortalProtocol, dst: Node, target: NodeId): Future[seq[Node]] {.async.} =
|
||||
@ -854,7 +863,7 @@ proc lookup*(p: PortalProtocol, target: NodeId): Future[seq[Node]] {.async.} =
|
||||
|
||||
proc triggerPoke*(
|
||||
p: PortalProtocol,
|
||||
nodes: seq[Node],
|
||||
nodes: seq[Node],
|
||||
contentKey: ByteList,
|
||||
content: seq[byte]) =
|
||||
## Triggers asynchronous offer-accept interaction to provided nodes.
|
||||
@ -871,7 +880,7 @@ proc triggerPoke*(
|
||||
raiseAssert(e.msg)
|
||||
else:
|
||||
# offer queue full, do not start more offer offer-accept interactions
|
||||
return
|
||||
return
|
||||
|
||||
# TODO ContentLookup and Lookup look almost exactly the same, also lookups in other
|
||||
# networks will probably be very similar. Extract lookup function to separate module
|
||||
@ -947,7 +956,7 @@ proc contentLookup*(p: PortalProtocol, target: ByteList, targetId: UInt256):
|
||||
|
||||
if closestNodes.len > BUCKET_SIZE:
|
||||
closestNodes.del(closestNodes.high())
|
||||
|
||||
|
||||
of Content:
|
||||
# cancel any pending queries as we have find the content
|
||||
for f in pendingQueries:
|
||||
|
@ -75,12 +75,25 @@ type
|
||||
udata: pointer
|
||||
contentHandler: ContentHandlerCallback
|
||||
|
||||
proc pruneAllowedConnections(stream: PortalStream) =
|
||||
# Prune requests and offers that didn't receive a connection request
|
||||
# before `connectionTimeout`.
|
||||
let now = Moment.now()
|
||||
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
|
||||
x.timeout > now)
|
||||
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
|
||||
x.timeout > now)
|
||||
|
||||
proc getUserData*[T](stream: PortalStream): T =
|
||||
## Obtain user data stored in ``stream`` object.
|
||||
cast[T](stream.udata)
|
||||
|
||||
proc addContentOffer*(
|
||||
stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList): Bytes2 =
|
||||
stream.pruneAllowedConnections()
|
||||
|
||||
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
|
||||
# What happens if we get duplicates?
|
||||
var connectionId: Bytes2
|
||||
brHmacDrbgGenerate(stream.rng[], connectionId)
|
||||
|
||||
@ -97,6 +110,10 @@ proc addContentOffer*(
|
||||
|
||||
proc addContentRequest*(
|
||||
stream: PortalStream, nodeId: NodeId, content: seq[byte]): Bytes2 =
|
||||
stream.pruneAllowedConnections()
|
||||
|
||||
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
|
||||
# What happens if we get duplicates?
|
||||
var connectionId: Bytes2
|
||||
brHmacDrbgGenerate(stream.rng[], connectionId)
|
||||
|
||||
@ -194,15 +211,6 @@ proc new*(
|
||||
func setTransport*(stream: PortalStream, transport: UtpDiscv5Protocol) =
|
||||
stream.transport = transport
|
||||
|
||||
proc pruneAllowedConnections(stream: PortalStream) =
|
||||
# Prune requests and offers that didn't receive a connection request
|
||||
# before `connectionTimeout`.
|
||||
let now = Moment.now()
|
||||
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
|
||||
x.timeout > now)
|
||||
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
|
||||
x.timeout > now)
|
||||
|
||||
# TODO: I think I'd like it more if we weren't to capture the stream.
|
||||
proc registerIncomingSocketCallback*(
|
||||
streams: seq[PortalStream]): AcceptConnectionCallback[NodeAddress] =
|
||||
|
@ -39,6 +39,9 @@ proc testHandlerSha256(contentKey: ByteList): Option[ContentId] =
|
||||
let idHash = sha256.digest(contentKey.asSeq())
|
||||
some(readUintBE[256](idHash.data))
|
||||
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
true
|
||||
|
||||
proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
|
||||
let
|
||||
node1 = initDiscoveryNode(
|
||||
@ -49,8 +52,10 @@ proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
|
||||
db1 = ContentDB.new("", inMemory = true)
|
||||
db2 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
|
||||
proto1 =
|
||||
PortalProtocol.new(node1, protocolId, db1, testHandler, validateContent)
|
||||
proto2 =
|
||||
PortalProtocol.new(node2, protocolId, db2, testHandler, validateContent)
|
||||
|
||||
Default2NodeTest(node1: node1, node2: node2, proto1: proto1, proto2: proto2)
|
||||
|
||||
@ -206,9 +211,12 @@ procSuite "Portal Wire Protocol Tests":
|
||||
db2 = ContentDB.new("", inMemory = true)
|
||||
db3 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
|
||||
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandler)
|
||||
proto1 = PortalProtocol.new(
|
||||
node1, protocolId, db1, testHandler, validateContent)
|
||||
proto2 = PortalProtocol.new(
|
||||
node2, protocolId, db2, testHandler, validateContent)
|
||||
proto3 = PortalProtocol.new(
|
||||
node3, protocolId, db3, testHandler, validateContent)
|
||||
|
||||
# Node1 knows about Node2, and Node2 knows about Node3 which hold all content
|
||||
check proto1.addNode(node2.localNode) == Added
|
||||
@ -239,14 +247,17 @@ procSuite "Portal Wire Protocol Tests":
|
||||
db2 = ContentDB.new("", inMemory = true)
|
||||
db3 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandlerSha256)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandlerSha256)
|
||||
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandlerSha256)
|
||||
proto1 = PortalProtocol.new(
|
||||
node1, protocolId, db1, testHandlerSha256, validateContent)
|
||||
proto2 = PortalProtocol.new(
|
||||
node2, protocolId, db2, testHandlerSha256, validateContent)
|
||||
proto3 = PortalProtocol.new(
|
||||
node3, protocolId, db3, testHandlerSha256, validateContent)
|
||||
|
||||
content = @[byte 1, 2]
|
||||
contentList = List[byte, 2048].init(content)
|
||||
contentId = readUintBE[256](sha256.digest(content).data)
|
||||
|
||||
|
||||
# Only node3 have content
|
||||
db3.put(contentId, content)
|
||||
|
||||
@ -283,8 +294,10 @@ procSuite "Portal Wire Protocol Tests":
|
||||
db1 = ContentDB.new("", inMemory = true)
|
||||
db2 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler,
|
||||
proto1 = PortalProtocol.new(
|
||||
node1, protocolId, db1, testHandler, validateContent)
|
||||
proto2 = PortalProtocol.new(
|
||||
node2, protocolId, db2, testHandler, validateContent,
|
||||
bootstrapRecords = [node1.localNode.record])
|
||||
|
||||
proto1.start()
|
||||
@ -307,7 +320,7 @@ procSuite "Portal Wire Protocol Tests":
|
||||
db = ContentDB.new("", inMemory = true)
|
||||
# No portal protocol for node1, hence an invalid bootstrap node
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db, testHandler,
|
||||
bootstrapRecords = [node1.localNode.record])
|
||||
validateContent, bootstrapRecords = [node1.localNode.record])
|
||||
|
||||
# seedTable to add node1 to the routing table
|
||||
proto2.seedTable()
|
||||
|
@ -187,6 +187,9 @@ proc testHandler(contentKey: ByteList): Option[ContentId] =
|
||||
let idHash = sha256.digest("test")
|
||||
some(readUintBE[256](idHash.data))
|
||||
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
true
|
||||
|
||||
proc run(config: PortalCliConf) =
|
||||
let
|
||||
rng = newRng()
|
||||
@ -212,7 +215,8 @@ proc run(config: PortalCliConf) =
|
||||
|
||||
let
|
||||
db = ContentDB.new("", inMemory = true)
|
||||
portal = PortalProtocol.new(d, config.protocolId, db, testHandler,
|
||||
portal = PortalProtocol.new(d, config.protocolId, db,
|
||||
testHandler, validateContent,
|
||||
bootstrapRecords = bootstrapRecords)
|
||||
socketConfig = SocketConfig.init(
|
||||
incomingSocketReceiveTimeout = none(Duration))
|
||||
|
Loading…
x
Reference in New Issue
Block a user