Add validation handler to Portal wire protocol (#1055)

- Validation handler allows for network specific validation of
the content
- Added also pruning of the allowed connections on PortalStream
This commit is contained in:
Kim De Mey 2022-04-11 19:42:38 +02:00 committed by GitHub
parent 82aab2a404
commit ed38ed494f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 110 additions and 55 deletions

View File

@ -55,14 +55,14 @@ func getEncodedKeyForContent(
return encodeKey(contentKey)
proc validateHeaderBytes*(
bytes: seq[byte], hash: BlockHash): Option[BlockHeader] =
bytes: openArray[byte], hash: BlockHash): Option[BlockHeader] =
try:
var rlp = rlpFromBytes(bytes)
let blockHeader = rlp.read(BlockHeader)
if not (blockHeader.blockHash() == hash):
# TODO: Header with different hash than expecte, maybe we should punish
# TODO: Header with different hash than expected, maybe we should punish
# peer which sent us this ?
return none(BlockHeader)
@ -73,7 +73,7 @@ proc validateHeaderBytes*(
return none(BlockHeader)
proc validateBodyBytes*(
bytes: seq[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
bytes: openArray[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
Option[BlockBody] =
try:
var rlp = rlpFromBytes(bytes)
@ -136,7 +136,7 @@ proc getBlockHeader*(
# Content is valid we can propagate it to interested peers
h.portalProtocol.triggerPoke(
headerContent.nodesInterestedInContent,
keyEncoded,
keyEncoded,
headerContent.content
)
@ -182,7 +182,7 @@ proc getBlock*(
# body is valid, propagate it to interested peers
h.portalProtocol.triggerPoke(
bodyContent.nodesInterestedInContent,
keyEncoded,
keyEncoded,
bodyContent.content
)
@ -192,7 +192,23 @@ proc getBlock*(
return some[Block]((header, blockBody))
# TODO Add getRecepits call
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
let keyOpt = contentKey.decode()
if keyOpt.isNone():
return false
let key = keyOpt.get()
case key.contentType:
of blockHeader:
validateHeaderBytes(content, key.blockHeaderKey.blockHash).isSome()
of blockBody:
true
# TODO: Need to get the header from the db or the network for this. Or how
# to deal with this?
of receipts:
true
proc new*(
T: type HistoryNetwork,
@ -202,7 +218,8 @@ proc new*(
bootstrapRecords: openArray[Record] = [],
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
let portalProtocol = PortalProtocol.new(
baseProtocol, historyProtocolId, contentDB, toContentIdHandler,
baseProtocol, historyProtocolId, contentDB,
toContentIdHandler, validateContent,
dataRadius, bootstrapRecords,
config = portalConfig)

View File

@ -56,6 +56,9 @@ proc getContent*(n: StateNetwork, key: ContentKey):
# domain types.
return some(contentResult.content)
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
true
proc new*(
T: type StateNetwork,
baseProtocol: protocol.Protocol,
@ -64,7 +67,8 @@ proc new*(
bootstrapRecords: openArray[Record] = [],
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
let portalProtocol = PortalProtocol.new(
baseProtocol, stateProtocolId, contentDB, toContentIdHandler,
baseProtocol, stateProtocolId, contentDB,
toContentIdHandler, validateContent,
dataRadius, bootstrapRecords, stateDistanceCalculator,
config = portalConfig)

View File

@ -117,10 +117,14 @@ type
ToContentIdHandler* =
proc(contentKey: ByteList): Option[ContentId] {.raises: [Defect], gcsafe.}
ContentValidationHandler* =
proc(content: openArray[byte], contentKey: ByteList):
bool {.raises: [Defect], gcsafe.}
PortalProtocolId* = array[2, byte]
RadiusCache* = LRUCache[NodeId, UInt256]
ContentInfo* = object
contentKey*: ByteList
content*: seq[byte]
@ -142,6 +146,7 @@ type
baseProtocol*: protocol.Protocol
contentDB*: ContentDB
toContentId: ToContentIdHandler
validateContent: ContentValidationHandler
dataRadius*: UInt256
bootstrapRecords*: seq[Record]
lastLookup: chronos.Moment
@ -173,7 +178,7 @@ type
nodesInterestedInContent*: seq[Node]
proc init*(
T: type ContentInfo,
T: type ContentInfo,
contentKey: ByteList,
content: seq[byte]): T =
ContentInfo(
@ -182,11 +187,11 @@ proc init*(
)
proc init*(
T: type ContentLookupResult,
content: seq[byte],
T: type ContentLookupResult,
content: seq[byte],
nodesInterestedInContent: seq[Node]): T =
ContentLookupResult(
content: content,
content: content,
nodesInterestedInContent: nodesInterestedInContent
)
@ -210,8 +215,8 @@ func neighbours*(p: PortalProtocol, id: NodeId, seenOnly = false): seq[Node] =
proc inRange(
p: PortalProtocol,
nodeId: NodeId,
nodeRadius: Uint256,
nodeId: NodeId,
nodeRadius: Uint256,
contentId: ContentId): bool =
let distance = p.routingTable.distance(nodeId, contentId)
distance <= nodeRadius
@ -414,6 +419,7 @@ proc new*(T: type PortalProtocol,
protocolId: PortalProtocolId,
contentDB: ContentDB,
toContentId: ToContentIdHandler,
validateContent: ContentValidationHandler,
dataRadius = UInt256.high(),
bootstrapRecords: openArray[Record] = [],
distanceCalculator: DistanceCalculator = XorDistanceCalculator,
@ -429,6 +435,7 @@ proc new*(T: type PortalProtocol,
baseProtocol: baseProtocol,
contentDB: contentDB,
toContentId: toContentId,
validateContent: validateContent,
dataRadius: dataRadius,
bootstrapRecords: @bootstrapRecords,
radiusCache: RadiusCache.init(256),
@ -612,7 +619,7 @@ proc findContent*(p: PortalProtocol, dst: Node, contentKey: ByteList):
else:
return err("Content message returned invalid ENRs")
proc getContentKeys(o: OfferRequest): ContentKeysList =
proc getContentKeys(o: OfferRequest): ContentKeysList =
case o.kind
of Direct:
var contentKeys:ContentKeysList
@ -640,7 +647,7 @@ proc offer(p: PortalProtocol, o: OfferRequest):
## to many peers, and keeping it all in memory could exhaust node resources.
## Main drawback is that content may be deleted from the node database
## by the cleanup process before it will be transferred, so this way does not
## guarantee content transfer
## guarantee content transfer.
let contentKeys = getContentKeys(o)
let acceptMessageResponse = await p.offerImpl(o.dst, contentKeys)
@ -716,7 +723,7 @@ proc offer*(p: PortalProtocol, dst: Node, content: seq[ContentInfo]):
Future[PortalResult[void]] {.async.} =
if len(content) > contentKeysLimit:
return err("Cannot offer more than 64 content items")
let contentList = List[ContentInfo, contentKeysLimit].init(content)
let req = OfferRequest(dst: dst, kind: Direct, contentList: contentList)
let res = await p.offer(req)
@ -759,24 +766,26 @@ proc processContent(
{.gcsafe, raises: [Defect].} =
let p = getUserData[PortalProtocol](stream)
# TODO: validate content
# - check amount of content items according to ContentKeysList
# - History Network specific: each content item, if header, check hash:
# this part of thevalidation will be specific per network & type and should
# be thus be custom per network
# TODO:
# - Implement a way to discern different content items (e.g. length prefixed)
# - Check amount of content items according to ContentKeysList
# - The above could also live in `PortalStream`
# TODO: for now we only consider 1 item being offered
if contentKeys.len() == 1:
let contentKey = contentKeys[0]
let contentIdOpt = p.toContentId(contentKey)
if contentIdOpt.isNone():
return
if p.validateContent(content, contentKey):
let contentIdOpt = p.toContentId(contentKey)
if contentIdOpt.isNone():
return
let contentId = contentIdOpt.get()
# Store content, should we recheck radius?
p.contentDB.put(contentId, content)
let contentId = contentIdOpt.get()
# Store content, should we recheck radius?
p.contentDB.put(contentId, content)
asyncSpawn neighborhoodGossip(p, contentKeys)
asyncSpawn neighborhoodGossip(p, contentKeys)
else:
error "Received invalid content", contentKey
proc lookupWorker(
p: PortalProtocol, dst: Node, target: NodeId): Future[seq[Node]] {.async.} =
@ -854,7 +863,7 @@ proc lookup*(p: PortalProtocol, target: NodeId): Future[seq[Node]] {.async.} =
proc triggerPoke*(
p: PortalProtocol,
nodes: seq[Node],
nodes: seq[Node],
contentKey: ByteList,
content: seq[byte]) =
## Triggers asynchronous offer-accept interaction to provided nodes.
@ -871,7 +880,7 @@ proc triggerPoke*(
raiseAssert(e.msg)
else:
# offer queue full, do not start more offer offer-accept interactions
return
return
# TODO ContentLookup and Lookup look almost exactly the same, also lookups in other
# networks will probably be very similar. Extract lookup function to separate module
@ -947,7 +956,7 @@ proc contentLookup*(p: PortalProtocol, target: ByteList, targetId: UInt256):
if closestNodes.len > BUCKET_SIZE:
closestNodes.del(closestNodes.high())
of Content:
# cancel any pending queries as we have find the content
for f in pendingQueries:

View File

@ -75,12 +75,25 @@ type
udata: pointer
contentHandler: ContentHandlerCallback
proc pruneAllowedConnections(stream: PortalStream) =
# Prune requests and offers that didn't receive a connection request
# before `connectionTimeout`.
let now = Moment.now()
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
x.timeout > now)
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
x.timeout > now)
proc getUserData*[T](stream: PortalStream): T =
## Obtain user data stored in ``stream`` object.
cast[T](stream.udata)
proc addContentOffer*(
stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList): Bytes2 =
stream.pruneAllowedConnections()
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
# What happens if we get duplicates?
var connectionId: Bytes2
brHmacDrbgGenerate(stream.rng[], connectionId)
@ -97,6 +110,10 @@ proc addContentOffer*(
proc addContentRequest*(
stream: PortalStream, nodeId: NodeId, content: seq[byte]): Bytes2 =
stream.pruneAllowedConnections()
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
# What happens if we get duplicates?
var connectionId: Bytes2
brHmacDrbgGenerate(stream.rng[], connectionId)
@ -194,15 +211,6 @@ proc new*(
func setTransport*(stream: PortalStream, transport: UtpDiscv5Protocol) =
stream.transport = transport
proc pruneAllowedConnections(stream: PortalStream) =
# Prune requests and offers that didn't receive a connection request
# before `connectionTimeout`.
let now = Moment.now()
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
x.timeout > now)
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
x.timeout > now)
# TODO: I think I'd like it more if we weren't to capture the stream.
proc registerIncomingSocketCallback*(
streams: seq[PortalStream]): AcceptConnectionCallback[NodeAddress] =

View File

@ -39,6 +39,9 @@ proc testHandlerSha256(contentKey: ByteList): Option[ContentId] =
let idHash = sha256.digest(contentKey.asSeq())
some(readUintBE[256](idHash.data))
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
true
proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
let
node1 = initDiscoveryNode(
@ -49,8 +52,10 @@ proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
db1 = ContentDB.new("", inMemory = true)
db2 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
proto1 =
PortalProtocol.new(node1, protocolId, db1, testHandler, validateContent)
proto2 =
PortalProtocol.new(node2, protocolId, db2, testHandler, validateContent)
Default2NodeTest(node1: node1, node2: node2, proto1: proto1, proto2: proto2)
@ -206,9 +211,12 @@ procSuite "Portal Wire Protocol Tests":
db2 = ContentDB.new("", inMemory = true)
db3 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandler)
proto1 = PortalProtocol.new(
node1, protocolId, db1, testHandler, validateContent)
proto2 = PortalProtocol.new(
node2, protocolId, db2, testHandler, validateContent)
proto3 = PortalProtocol.new(
node3, protocolId, db3, testHandler, validateContent)
# Node1 knows about Node2, and Node2 knows about Node3 which hold all content
check proto1.addNode(node2.localNode) == Added
@ -239,14 +247,17 @@ procSuite "Portal Wire Protocol Tests":
db2 = ContentDB.new("", inMemory = true)
db3 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandlerSha256)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandlerSha256)
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandlerSha256)
proto1 = PortalProtocol.new(
node1, protocolId, db1, testHandlerSha256, validateContent)
proto2 = PortalProtocol.new(
node2, protocolId, db2, testHandlerSha256, validateContent)
proto3 = PortalProtocol.new(
node3, protocolId, db3, testHandlerSha256, validateContent)
content = @[byte 1, 2]
contentList = List[byte, 2048].init(content)
contentId = readUintBE[256](sha256.digest(content).data)
# Only node3 have content
db3.put(contentId, content)
@ -283,8 +294,10 @@ procSuite "Portal Wire Protocol Tests":
db1 = ContentDB.new("", inMemory = true)
db2 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler,
proto1 = PortalProtocol.new(
node1, protocolId, db1, testHandler, validateContent)
proto2 = PortalProtocol.new(
node2, protocolId, db2, testHandler, validateContent,
bootstrapRecords = [node1.localNode.record])
proto1.start()
@ -307,7 +320,7 @@ procSuite "Portal Wire Protocol Tests":
db = ContentDB.new("", inMemory = true)
# No portal protocol for node1, hence an invalid bootstrap node
proto2 = PortalProtocol.new(node2, protocolId, db, testHandler,
bootstrapRecords = [node1.localNode.record])
validateContent, bootstrapRecords = [node1.localNode.record])
# seedTable to add node1 to the routing table
proto2.seedTable()

View File

@ -187,6 +187,9 @@ proc testHandler(contentKey: ByteList): Option[ContentId] =
let idHash = sha256.digest("test")
some(readUintBE[256](idHash.data))
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
true
proc run(config: PortalCliConf) =
let
rng = newRng()
@ -212,7 +215,8 @@ proc run(config: PortalCliConf) =
let
db = ContentDB.new("", inMemory = true)
portal = PortalProtocol.new(d, config.protocolId, db, testHandler,
portal = PortalProtocol.new(d, config.protocolId, db,
testHandler, validateContent,
bootstrapRecords = bootstrapRecords)
socketConfig = SocketConfig.init(
incomingSocketReceiveTimeout = none(Duration))