Add validation handler to Portal wire protocol (#1055)

- Validation handler allows for network specific validation of
the content
- Added also pruning of the allowed connections on PortalStream
This commit is contained in:
Kim De Mey 2022-04-11 19:42:38 +02:00 committed by GitHub
parent 82aab2a404
commit ed38ed494f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 110 additions and 55 deletions

View File

@ -55,14 +55,14 @@ func getEncodedKeyForContent(
return encodeKey(contentKey)
proc validateHeaderBytes*(
bytes: seq[byte], hash: BlockHash): Option[BlockHeader] =
bytes: openArray[byte], hash: BlockHash): Option[BlockHeader] =
try:
var rlp = rlpFromBytes(bytes)
let blockHeader = rlp.read(BlockHeader)
if not (blockHeader.blockHash() == hash):
# TODO: Header with different hash than expecte, maybe we should punish
# TODO: Header with different hash than expected, maybe we should punish
# peer which sent us this ?
return none(BlockHeader)
@ -73,7 +73,7 @@ proc validateHeaderBytes*(
return none(BlockHeader)
proc validateBodyBytes*(
bytes: seq[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
bytes: openArray[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
Option[BlockBody] =
try:
var rlp = rlpFromBytes(bytes)
@ -192,7 +192,23 @@ proc getBlock*(
return some[Block]((header, blockBody))
# TODO Add getRecepits call
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
let keyOpt = contentKey.decode()
if keyOpt.isNone():
return false
let key = keyOpt.get()
case key.contentType:
of blockHeader:
validateHeaderBytes(content, key.blockHeaderKey.blockHash).isSome()
of blockBody:
true
# TODO: Need to get the header from the db or the network for this. Or how
# to deal with this?
of receipts:
true
proc new*(
T: type HistoryNetwork,
@ -202,7 +218,8 @@ proc new*(
bootstrapRecords: openArray[Record] = [],
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
let portalProtocol = PortalProtocol.new(
baseProtocol, historyProtocolId, contentDB, toContentIdHandler,
baseProtocol, historyProtocolId, contentDB,
toContentIdHandler, validateContent,
dataRadius, bootstrapRecords,
config = portalConfig)

View File

@ -56,6 +56,9 @@ proc getContent*(n: StateNetwork, key: ContentKey):
# domain types.
return some(contentResult.content)
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
true
proc new*(
T: type StateNetwork,
baseProtocol: protocol.Protocol,
@ -64,7 +67,8 @@ proc new*(
bootstrapRecords: openArray[Record] = [],
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
let portalProtocol = PortalProtocol.new(
baseProtocol, stateProtocolId, contentDB, toContentIdHandler,
baseProtocol, stateProtocolId, contentDB,
toContentIdHandler, validateContent,
dataRadius, bootstrapRecords, stateDistanceCalculator,
config = portalConfig)

View File

@ -117,6 +117,10 @@ type
ToContentIdHandler* =
proc(contentKey: ByteList): Option[ContentId] {.raises: [Defect], gcsafe.}
ContentValidationHandler* =
proc(content: openArray[byte], contentKey: ByteList):
bool {.raises: [Defect], gcsafe.}
PortalProtocolId* = array[2, byte]
RadiusCache* = LRUCache[NodeId, UInt256]
@ -142,6 +146,7 @@ type
baseProtocol*: protocol.Protocol
contentDB*: ContentDB
toContentId: ToContentIdHandler
validateContent: ContentValidationHandler
dataRadius*: UInt256
bootstrapRecords*: seq[Record]
lastLookup: chronos.Moment
@ -414,6 +419,7 @@ proc new*(T: type PortalProtocol,
protocolId: PortalProtocolId,
contentDB: ContentDB,
toContentId: ToContentIdHandler,
validateContent: ContentValidationHandler,
dataRadius = UInt256.high(),
bootstrapRecords: openArray[Record] = [],
distanceCalculator: DistanceCalculator = XorDistanceCalculator,
@ -429,6 +435,7 @@ proc new*(T: type PortalProtocol,
baseProtocol: baseProtocol,
contentDB: contentDB,
toContentId: toContentId,
validateContent: validateContent,
dataRadius: dataRadius,
bootstrapRecords: @bootstrapRecords,
radiusCache: RadiusCache.init(256),
@ -640,7 +647,7 @@ proc offer(p: PortalProtocol, o: OfferRequest):
## to many peers, and keeping it all in memory could exhaust node resources.
## Main drawback is that content may be deleted from the node database
## by the cleanup process before it will be transferred, so this way does not
## guarantee content transfer
## guarantee content transfer.
let contentKeys = getContentKeys(o)
let acceptMessageResponse = await p.offerImpl(o.dst, contentKeys)
@ -759,15 +766,15 @@ proc processContent(
{.gcsafe, raises: [Defect].} =
let p = getUserData[PortalProtocol](stream)
# TODO: validate content
# - check amount of content items according to ContentKeysList
# - History Network specific: each content item, if header, check hash:
# this part of thevalidation will be specific per network & type and should
# be thus be custom per network
# TODO:
# - Implement a way to discern different content items (e.g. length prefixed)
# - Check amount of content items according to ContentKeysList
# - The above could also live in `PortalStream`
# TODO: for now we only consider 1 item being offered
if contentKeys.len() == 1:
let contentKey = contentKeys[0]
if p.validateContent(content, contentKey):
let contentIdOpt = p.toContentId(contentKey)
if contentIdOpt.isNone():
return
@ -777,6 +784,8 @@ proc processContent(
p.contentDB.put(contentId, content)
asyncSpawn neighborhoodGossip(p, contentKeys)
else:
error "Received invalid content", contentKey
proc lookupWorker(
p: PortalProtocol, dst: Node, target: NodeId): Future[seq[Node]] {.async.} =

View File

@ -75,12 +75,25 @@ type
udata: pointer
contentHandler: ContentHandlerCallback
proc pruneAllowedConnections(stream: PortalStream) =
# Prune requests and offers that didn't receive a connection request
# before `connectionTimeout`.
let now = Moment.now()
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
x.timeout > now)
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
x.timeout > now)
proc getUserData*[T](stream: PortalStream): T =
## Obtain user data stored in ``stream`` object.
cast[T](stream.udata)
proc addContentOffer*(
stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList): Bytes2 =
stream.pruneAllowedConnections()
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
# What happens if we get duplicates?
var connectionId: Bytes2
brHmacDrbgGenerate(stream.rng[], connectionId)
@ -97,6 +110,10 @@ proc addContentOffer*(
proc addContentRequest*(
stream: PortalStream, nodeId: NodeId, content: seq[byte]): Bytes2 =
stream.pruneAllowedConnections()
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
# What happens if we get duplicates?
var connectionId: Bytes2
brHmacDrbgGenerate(stream.rng[], connectionId)
@ -194,15 +211,6 @@ proc new*(
func setTransport*(stream: PortalStream, transport: UtpDiscv5Protocol) =
stream.transport = transport
proc pruneAllowedConnections(stream: PortalStream) =
# Prune requests and offers that didn't receive a connection request
# before `connectionTimeout`.
let now = Moment.now()
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
x.timeout > now)
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
x.timeout > now)
# TODO: I think I'd like it more if we weren't to capture the stream.
proc registerIncomingSocketCallback*(
streams: seq[PortalStream]): AcceptConnectionCallback[NodeAddress] =

View File

@ -39,6 +39,9 @@ proc testHandlerSha256(contentKey: ByteList): Option[ContentId] =
let idHash = sha256.digest(contentKey.asSeq())
some(readUintBE[256](idHash.data))
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
true
proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
let
node1 = initDiscoveryNode(
@ -49,8 +52,10 @@ proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
db1 = ContentDB.new("", inMemory = true)
db2 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
proto1 =
PortalProtocol.new(node1, protocolId, db1, testHandler, validateContent)
proto2 =
PortalProtocol.new(node2, protocolId, db2, testHandler, validateContent)
Default2NodeTest(node1: node1, node2: node2, proto1: proto1, proto2: proto2)
@ -206,9 +211,12 @@ procSuite "Portal Wire Protocol Tests":
db2 = ContentDB.new("", inMemory = true)
db3 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandler)
proto1 = PortalProtocol.new(
node1, protocolId, db1, testHandler, validateContent)
proto2 = PortalProtocol.new(
node2, protocolId, db2, testHandler, validateContent)
proto3 = PortalProtocol.new(
node3, protocolId, db3, testHandler, validateContent)
# Node1 knows about Node2, and Node2 knows about Node3 which hold all content
check proto1.addNode(node2.localNode) == Added
@ -239,9 +247,12 @@ procSuite "Portal Wire Protocol Tests":
db2 = ContentDB.new("", inMemory = true)
db3 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandlerSha256)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandlerSha256)
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandlerSha256)
proto1 = PortalProtocol.new(
node1, protocolId, db1, testHandlerSha256, validateContent)
proto2 = PortalProtocol.new(
node2, protocolId, db2, testHandlerSha256, validateContent)
proto3 = PortalProtocol.new(
node3, protocolId, db3, testHandlerSha256, validateContent)
content = @[byte 1, 2]
contentList = List[byte, 2048].init(content)
@ -283,8 +294,10 @@ procSuite "Portal Wire Protocol Tests":
db1 = ContentDB.new("", inMemory = true)
db2 = ContentDB.new("", inMemory = true)
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler,
proto1 = PortalProtocol.new(
node1, protocolId, db1, testHandler, validateContent)
proto2 = PortalProtocol.new(
node2, protocolId, db2, testHandler, validateContent,
bootstrapRecords = [node1.localNode.record])
proto1.start()
@ -307,7 +320,7 @@ procSuite "Portal Wire Protocol Tests":
db = ContentDB.new("", inMemory = true)
# No portal protocol for node1, hence an invalid bootstrap node
proto2 = PortalProtocol.new(node2, protocolId, db, testHandler,
bootstrapRecords = [node1.localNode.record])
validateContent, bootstrapRecords = [node1.localNode.record])
# seedTable to add node1 to the routing table
proto2.seedTable()

View File

@ -187,6 +187,9 @@ proc testHandler(contentKey: ByteList): Option[ContentId] =
let idHash = sha256.digest("test")
some(readUintBE[256](idHash.data))
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
true
proc run(config: PortalCliConf) =
let
rng = newRng()
@ -212,7 +215,8 @@ proc run(config: PortalCliConf) =
let
db = ContentDB.new("", inMemory = true)
portal = PortalProtocol.new(d, config.protocolId, db, testHandler,
portal = PortalProtocol.new(d, config.protocolId, db,
testHandler, validateContent,
bootstrapRecords = bootstrapRecords)
socketConfig = SocketConfig.init(
incomingSocketReceiveTimeout = none(Duration))