Add validation handler to Portal wire protocol (#1055)
- Validation handler allows for network specific validation of the content - Added also pruning of the allowed connections on PortalStream
This commit is contained in:
parent
82aab2a404
commit
ed38ed494f
|
@ -55,14 +55,14 @@ func getEncodedKeyForContent(
|
|||
return encodeKey(contentKey)
|
||||
|
||||
proc validateHeaderBytes*(
|
||||
bytes: seq[byte], hash: BlockHash): Option[BlockHeader] =
|
||||
bytes: openArray[byte], hash: BlockHash): Option[BlockHeader] =
|
||||
try:
|
||||
var rlp = rlpFromBytes(bytes)
|
||||
|
||||
let blockHeader = rlp.read(BlockHeader)
|
||||
|
||||
if not (blockHeader.blockHash() == hash):
|
||||
# TODO: Header with different hash than expecte, maybe we should punish
|
||||
# TODO: Header with different hash than expected, maybe we should punish
|
||||
# peer which sent us this ?
|
||||
return none(BlockHeader)
|
||||
|
||||
|
@ -73,7 +73,7 @@ proc validateHeaderBytes*(
|
|||
return none(BlockHeader)
|
||||
|
||||
proc validateBodyBytes*(
|
||||
bytes: seq[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
|
||||
bytes: openArray[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
|
||||
Option[BlockBody] =
|
||||
try:
|
||||
var rlp = rlpFromBytes(bytes)
|
||||
|
@ -192,7 +192,23 @@ proc getBlock*(
|
|||
|
||||
return some[Block]((header, blockBody))
|
||||
|
||||
# TODO Add getRecepits call
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
let keyOpt = contentKey.decode()
|
||||
|
||||
if keyOpt.isNone():
|
||||
return false
|
||||
|
||||
let key = keyOpt.get()
|
||||
|
||||
case key.contentType:
|
||||
of blockHeader:
|
||||
validateHeaderBytes(content, key.blockHeaderKey.blockHash).isSome()
|
||||
of blockBody:
|
||||
true
|
||||
# TODO: Need to get the header from the db or the network for this. Or how
|
||||
# to deal with this?
|
||||
of receipts:
|
||||
true
|
||||
|
||||
proc new*(
|
||||
T: type HistoryNetwork,
|
||||
|
@ -202,7 +218,8 @@ proc new*(
|
|||
bootstrapRecords: openArray[Record] = [],
|
||||
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
|
||||
let portalProtocol = PortalProtocol.new(
|
||||
baseProtocol, historyProtocolId, contentDB, toContentIdHandler,
|
||||
baseProtocol, historyProtocolId, contentDB,
|
||||
toContentIdHandler, validateContent,
|
||||
dataRadius, bootstrapRecords,
|
||||
config = portalConfig)
|
||||
|
||||
|
|
|
@ -56,6 +56,9 @@ proc getContent*(n: StateNetwork, key: ContentKey):
|
|||
# domain types.
|
||||
return some(contentResult.content)
|
||||
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
true
|
||||
|
||||
proc new*(
|
||||
T: type StateNetwork,
|
||||
baseProtocol: protocol.Protocol,
|
||||
|
@ -64,7 +67,8 @@ proc new*(
|
|||
bootstrapRecords: openArray[Record] = [],
|
||||
portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
|
||||
let portalProtocol = PortalProtocol.new(
|
||||
baseProtocol, stateProtocolId, contentDB, toContentIdHandler,
|
||||
baseProtocol, stateProtocolId, contentDB,
|
||||
toContentIdHandler, validateContent,
|
||||
dataRadius, bootstrapRecords, stateDistanceCalculator,
|
||||
config = portalConfig)
|
||||
|
||||
|
|
|
@ -117,6 +117,10 @@ type
|
|||
ToContentIdHandler* =
|
||||
proc(contentKey: ByteList): Option[ContentId] {.raises: [Defect], gcsafe.}
|
||||
|
||||
ContentValidationHandler* =
|
||||
proc(content: openArray[byte], contentKey: ByteList):
|
||||
bool {.raises: [Defect], gcsafe.}
|
||||
|
||||
PortalProtocolId* = array[2, byte]
|
||||
|
||||
RadiusCache* = LRUCache[NodeId, UInt256]
|
||||
|
@ -142,6 +146,7 @@ type
|
|||
baseProtocol*: protocol.Protocol
|
||||
contentDB*: ContentDB
|
||||
toContentId: ToContentIdHandler
|
||||
validateContent: ContentValidationHandler
|
||||
dataRadius*: UInt256
|
||||
bootstrapRecords*: seq[Record]
|
||||
lastLookup: chronos.Moment
|
||||
|
@ -414,6 +419,7 @@ proc new*(T: type PortalProtocol,
|
|||
protocolId: PortalProtocolId,
|
||||
contentDB: ContentDB,
|
||||
toContentId: ToContentIdHandler,
|
||||
validateContent: ContentValidationHandler,
|
||||
dataRadius = UInt256.high(),
|
||||
bootstrapRecords: openArray[Record] = [],
|
||||
distanceCalculator: DistanceCalculator = XorDistanceCalculator,
|
||||
|
@ -429,6 +435,7 @@ proc new*(T: type PortalProtocol,
|
|||
baseProtocol: baseProtocol,
|
||||
contentDB: contentDB,
|
||||
toContentId: toContentId,
|
||||
validateContent: validateContent,
|
||||
dataRadius: dataRadius,
|
||||
bootstrapRecords: @bootstrapRecords,
|
||||
radiusCache: RadiusCache.init(256),
|
||||
|
@ -640,7 +647,7 @@ proc offer(p: PortalProtocol, o: OfferRequest):
|
|||
## to many peers, and keeping it all in memory could exhaust node resources.
|
||||
## Main drawback is that content may be deleted from the node database
|
||||
## by the cleanup process before it will be transferred, so this way does not
|
||||
## guarantee content transfer
|
||||
## guarantee content transfer.
|
||||
let contentKeys = getContentKeys(o)
|
||||
|
||||
let acceptMessageResponse = await p.offerImpl(o.dst, contentKeys)
|
||||
|
@ -759,15 +766,15 @@ proc processContent(
|
|||
{.gcsafe, raises: [Defect].} =
|
||||
let p = getUserData[PortalProtocol](stream)
|
||||
|
||||
# TODO: validate content
|
||||
# - check amount of content items according to ContentKeysList
|
||||
# - History Network specific: each content item, if header, check hash:
|
||||
# this part of thevalidation will be specific per network & type and should
|
||||
# be thus be custom per network
|
||||
# TODO:
|
||||
# - Implement a way to discern different content items (e.g. length prefixed)
|
||||
# - Check amount of content items according to ContentKeysList
|
||||
# - The above could also live in `PortalStream`
|
||||
|
||||
# TODO: for now we only consider 1 item being offered
|
||||
if contentKeys.len() == 1:
|
||||
let contentKey = contentKeys[0]
|
||||
if p.validateContent(content, contentKey):
|
||||
let contentIdOpt = p.toContentId(contentKey)
|
||||
if contentIdOpt.isNone():
|
||||
return
|
||||
|
@ -777,6 +784,8 @@ proc processContent(
|
|||
p.contentDB.put(contentId, content)
|
||||
|
||||
asyncSpawn neighborhoodGossip(p, contentKeys)
|
||||
else:
|
||||
error "Received invalid content", contentKey
|
||||
|
||||
proc lookupWorker(
|
||||
p: PortalProtocol, dst: Node, target: NodeId): Future[seq[Node]] {.async.} =
|
||||
|
|
|
@ -75,12 +75,25 @@ type
|
|||
udata: pointer
|
||||
contentHandler: ContentHandlerCallback
|
||||
|
||||
proc pruneAllowedConnections(stream: PortalStream) =
|
||||
# Prune requests and offers that didn't receive a connection request
|
||||
# before `connectionTimeout`.
|
||||
let now = Moment.now()
|
||||
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
|
||||
x.timeout > now)
|
||||
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
|
||||
x.timeout > now)
|
||||
|
||||
proc getUserData*[T](stream: PortalStream): T =
|
||||
## Obtain user data stored in ``stream`` object.
|
||||
cast[T](stream.udata)
|
||||
|
||||
proc addContentOffer*(
|
||||
stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList): Bytes2 =
|
||||
stream.pruneAllowedConnections()
|
||||
|
||||
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
|
||||
# What happens if we get duplicates?
|
||||
var connectionId: Bytes2
|
||||
brHmacDrbgGenerate(stream.rng[], connectionId)
|
||||
|
||||
|
@ -97,6 +110,10 @@ proc addContentOffer*(
|
|||
|
||||
proc addContentRequest*(
|
||||
stream: PortalStream, nodeId: NodeId, content: seq[byte]): Bytes2 =
|
||||
stream.pruneAllowedConnections()
|
||||
|
||||
# TODO: Should we check if `NodeId` & `connectionId` combo already exists?
|
||||
# What happens if we get duplicates?
|
||||
var connectionId: Bytes2
|
||||
brHmacDrbgGenerate(stream.rng[], connectionId)
|
||||
|
||||
|
@ -194,15 +211,6 @@ proc new*(
|
|||
func setTransport*(stream: PortalStream, transport: UtpDiscv5Protocol) =
|
||||
stream.transport = transport
|
||||
|
||||
proc pruneAllowedConnections(stream: PortalStream) =
|
||||
# Prune requests and offers that didn't receive a connection request
|
||||
# before `connectionTimeout`.
|
||||
let now = Moment.now()
|
||||
stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
|
||||
x.timeout > now)
|
||||
stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
|
||||
x.timeout > now)
|
||||
|
||||
# TODO: I think I'd like it more if we weren't to capture the stream.
|
||||
proc registerIncomingSocketCallback*(
|
||||
streams: seq[PortalStream]): AcceptConnectionCallback[NodeAddress] =
|
||||
|
|
|
@ -39,6 +39,9 @@ proc testHandlerSha256(contentKey: ByteList): Option[ContentId] =
|
|||
let idHash = sha256.digest(contentKey.asSeq())
|
||||
some(readUintBE[256](idHash.data))
|
||||
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
true
|
||||
|
||||
proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
|
||||
let
|
||||
node1 = initDiscoveryNode(
|
||||
|
@ -49,8 +52,10 @@ proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
|
|||
db1 = ContentDB.new("", inMemory = true)
|
||||
db2 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
|
||||
proto1 =
|
||||
PortalProtocol.new(node1, protocolId, db1, testHandler, validateContent)
|
||||
proto2 =
|
||||
PortalProtocol.new(node2, protocolId, db2, testHandler, validateContent)
|
||||
|
||||
Default2NodeTest(node1: node1, node2: node2, proto1: proto1, proto2: proto2)
|
||||
|
||||
|
@ -206,9 +211,12 @@ procSuite "Portal Wire Protocol Tests":
|
|||
db2 = ContentDB.new("", inMemory = true)
|
||||
db3 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
|
||||
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandler)
|
||||
proto1 = PortalProtocol.new(
|
||||
node1, protocolId, db1, testHandler, validateContent)
|
||||
proto2 = PortalProtocol.new(
|
||||
node2, protocolId, db2, testHandler, validateContent)
|
||||
proto3 = PortalProtocol.new(
|
||||
node3, protocolId, db3, testHandler, validateContent)
|
||||
|
||||
# Node1 knows about Node2, and Node2 knows about Node3 which hold all content
|
||||
check proto1.addNode(node2.localNode) == Added
|
||||
|
@ -239,9 +247,12 @@ procSuite "Portal Wire Protocol Tests":
|
|||
db2 = ContentDB.new("", inMemory = true)
|
||||
db3 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandlerSha256)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandlerSha256)
|
||||
proto3 = PortalProtocol.new(node3, protocolId, db3, testHandlerSha256)
|
||||
proto1 = PortalProtocol.new(
|
||||
node1, protocolId, db1, testHandlerSha256, validateContent)
|
||||
proto2 = PortalProtocol.new(
|
||||
node2, protocolId, db2, testHandlerSha256, validateContent)
|
||||
proto3 = PortalProtocol.new(
|
||||
node3, protocolId, db3, testHandlerSha256, validateContent)
|
||||
|
||||
content = @[byte 1, 2]
|
||||
contentList = List[byte, 2048].init(content)
|
||||
|
@ -283,8 +294,10 @@ procSuite "Portal Wire Protocol Tests":
|
|||
db1 = ContentDB.new("", inMemory = true)
|
||||
db2 = ContentDB.new("", inMemory = true)
|
||||
|
||||
proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler,
|
||||
proto1 = PortalProtocol.new(
|
||||
node1, protocolId, db1, testHandler, validateContent)
|
||||
proto2 = PortalProtocol.new(
|
||||
node2, protocolId, db2, testHandler, validateContent,
|
||||
bootstrapRecords = [node1.localNode.record])
|
||||
|
||||
proto1.start()
|
||||
|
@ -307,7 +320,7 @@ procSuite "Portal Wire Protocol Tests":
|
|||
db = ContentDB.new("", inMemory = true)
|
||||
# No portal protocol for node1, hence an invalid bootstrap node
|
||||
proto2 = PortalProtocol.new(node2, protocolId, db, testHandler,
|
||||
bootstrapRecords = [node1.localNode.record])
|
||||
validateContent, bootstrapRecords = [node1.localNode.record])
|
||||
|
||||
# seedTable to add node1 to the routing table
|
||||
proto2.seedTable()
|
||||
|
|
|
@ -187,6 +187,9 @@ proc testHandler(contentKey: ByteList): Option[ContentId] =
|
|||
let idHash = sha256.digest("test")
|
||||
some(readUintBE[256](idHash.data))
|
||||
|
||||
proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
|
||||
true
|
||||
|
||||
proc run(config: PortalCliConf) =
|
||||
let
|
||||
rng = newRng()
|
||||
|
@ -212,7 +215,8 @@ proc run(config: PortalCliConf) =
|
|||
|
||||
let
|
||||
db = ContentDB.new("", inMemory = true)
|
||||
portal = PortalProtocol.new(d, config.protocolId, db, testHandler,
|
||||
portal = PortalProtocol.new(d, config.protocolId, db,
|
||||
testHandler, validateContent,
|
||||
bootstrapRecords = bootstrapRecords)
|
||||
socketConfig = SocketConfig.init(
|
||||
incomingSocketReceiveTimeout = none(Duration))
|
||||
|
|
Loading…
Reference in New Issue