diff --git a/fluffy/network/history/history_network.nim b/fluffy/network/history/history_network.nim
index 706bedc09..3bb081171 100644
--- a/fluffy/network/history/history_network.nim
+++ b/fluffy/network/history/history_network.nim
@@ -55,14 +55,14 @@ func getEncodedKeyForContent(
   return encodeKey(contentKey)
 
 proc validateHeaderBytes*(
-    bytes: seq[byte], hash: BlockHash): Option[BlockHeader] =
+    bytes: openArray[byte], hash: BlockHash): Option[BlockHeader] =
   try:
     var rlp = rlpFromBytes(bytes)
 
     let blockHeader = rlp.read(BlockHeader)
 
     if not (blockHeader.blockHash() == hash):
-      # TODO: Header with different hash than expecte, maybe we should punish
+      # TODO: Header with different hash than expected, maybe we should punish
       # peer which sent us this ?
       return none(BlockHeader)
 
@@ -73,7 +73,7 @@ proc validateHeaderBytes*(
     return none(BlockHeader)
 
 proc validateBodyBytes*(
-    bytes: seq[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
+    bytes: openArray[byte], txRoot: KeccakHash, ommersHash: KeccakHash):
     Option[BlockBody] =
   try:
     var rlp = rlpFromBytes(bytes)
@@ -136,7 +136,7 @@ proc getBlockHeader*(
     # Content is valid we can propagate it to interested peers
     h.portalProtocol.triggerPoke(
       headerContent.nodesInterestedInContent,
-      keyEncoded, 
+      keyEncoded,
       headerContent.content
     )
 
@@ -182,7 +182,7 @@ proc getBlock*(
     # body is valid, propagate it to interested peers
     h.portalProtocol.triggerPoke(
       bodyContent.nodesInterestedInContent,
-      keyEncoded, 
+      keyEncoded,
       bodyContent.content
     )
 
@@ -192,7 +192,23 @@ proc getBlock*(
 
   return some[Block]((header, blockBody))
 
-# TODO Add getRecepits call
+proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
+  let keyOpt = contentKey.decode()
+
+  if keyOpt.isNone():
+    return false
+
+  let key = keyOpt.get()
+
+  case key.contentType:
+  of blockHeader:
+    validateHeaderBytes(content, key.blockHeaderKey.blockHash).isSome()
+  of blockBody:
+    true
+    # TODO: Need to get the header from the db or the network for this. Or how
+    # to deal with this?
+  of receipts:
+    true
 
 proc new*(
     T: type HistoryNetwork,
@@ -202,7 +218,8 @@ proc new*(
     bootstrapRecords: openArray[Record] = [],
     portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
   let portalProtocol = PortalProtocol.new(
-    baseProtocol, historyProtocolId, contentDB, toContentIdHandler,
+    baseProtocol, historyProtocolId, contentDB,
+    toContentIdHandler, validateContent,
     dataRadius, bootstrapRecords, config = portalConfig)
 
diff --git a/fluffy/network/state/state_network.nim b/fluffy/network/state/state_network.nim
index 62141fa48..f1248d62f 100644
--- a/fluffy/network/state/state_network.nim
+++ b/fluffy/network/state/state_network.nim
@@ -56,6 +56,9 @@ proc getContent*(n: StateNetwork, key: ContentKey):
   # domain types.
   return some(contentResult.content)
 
+proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
+  true
+
 proc new*(
     T: type StateNetwork,
     baseProtocol: protocol.Protocol,
@@ -64,7 +67,8 @@ proc new*(
     bootstrapRecords: openArray[Record] = [],
     portalConfig: PortalProtocolConfig = defaultPortalProtocolConfig): T =
   let portalProtocol = PortalProtocol.new(
-    baseProtocol, stateProtocolId, contentDB, toContentIdHandler,
+    baseProtocol, stateProtocolId, contentDB,
+    toContentIdHandler, validateContent,
     dataRadius, bootstrapRecords, stateDistanceCalculator,
     config = portalConfig)
 
diff --git a/fluffy/network/wire/portal_protocol.nim b/fluffy/network/wire/portal_protocol.nim
index 6cefdebae..06f37c4fb 100644
--- a/fluffy/network/wire/portal_protocol.nim
+++ b/fluffy/network/wire/portal_protocol.nim
@@ -117,10 +117,14 @@ type
   ToContentIdHandler* =
     proc(contentKey: ByteList): Option[ContentId] {.raises: [Defect], gcsafe.}
 
+  ContentValidationHandler* =
+    proc(content: openArray[byte], contentKey: ByteList):
+      bool {.raises: [Defect], gcsafe.}
+
   PortalProtocolId* = array[2, byte]
 
   RadiusCache* = LRUCache[NodeId, UInt256]
-  
+
   ContentInfo* = object
     contentKey*: ByteList
     content*: seq[byte]
@@ -142,6 +146,7 @@ type
     baseProtocol*: protocol.Protocol
     contentDB*: ContentDB
     toContentId: ToContentIdHandler
+    validateContent: ContentValidationHandler
     dataRadius*: UInt256
     bootstrapRecords*: seq[Record]
     lastLookup: chronos.Moment
@@ -173,7 +178,7 @@ type
     nodesInterestedInContent*: seq[Node]
 
 proc init*(
-    T: type ContentInfo, 
+    T: type ContentInfo,
     contentKey: ByteList,
     content: seq[byte]): T =
   ContentInfo(
@@ -182,11 +187,11 @@ proc init*(
   )
 
 proc init*(
-    T: type ContentLookupResult, 
-    content: seq[byte], 
+    T: type ContentLookupResult,
+    content: seq[byte],
     nodesInterestedInContent: seq[Node]): T =
   ContentLookupResult(
-    content: content, 
+    content: content,
     nodesInterestedInContent: nodesInterestedInContent
   )
 
@@ -210,8 +215,8 @@ func neighbours*(p: PortalProtocol, id: NodeId, seenOnly = false): seq[Node] =
 
 proc inRange(
     p: PortalProtocol,
-    nodeId: NodeId, 
-    nodeRadius: Uint256, 
+    nodeId: NodeId,
+    nodeRadius: Uint256,
     contentId: ContentId): bool =
   let distance = p.routingTable.distance(nodeId, contentId)
   distance <= nodeRadius
@@ -414,6 +419,7 @@ proc new*(T: type PortalProtocol,
     protocolId: PortalProtocolId,
     contentDB: ContentDB,
     toContentId: ToContentIdHandler,
+    validateContent: ContentValidationHandler,
     dataRadius = UInt256.high(),
     bootstrapRecords: openArray[Record] = [],
     distanceCalculator: DistanceCalculator = XorDistanceCalculator,
@@ -429,6 +435,7 @@ proc new*(T: type PortalProtocol,
     baseProtocol: baseProtocol,
     contentDB: contentDB,
     toContentId: toContentId,
+    validateContent: validateContent,
     dataRadius: dataRadius,
     bootstrapRecords: @bootstrapRecords,
     radiusCache: RadiusCache.init(256),
@@ -612,7 +619,7 @@ proc findContent*(p: PortalProtocol, dst: Node, contentKey: ByteList):
   else:
     return err("Content message returned invalid ENRs")
 
-proc getContentKeys(o: OfferRequest): ContentKeysList = 
+proc getContentKeys(o: OfferRequest): ContentKeysList =
   case o.kind
   of Direct:
     var contentKeys:ContentKeysList
@@ -640,7 +647,7 @@ proc offer(p: PortalProtocol, o: OfferRequest):
   ## to many peers, and keeping it all in memory could exhaust node resources.
   ## Main drawback is that content may be deleted from the node database
   ## by the cleanup process before it will be transferred, so this way does not
-  ## guarantee content transfer
+  ## guarantee content transfer.
   let contentKeys = getContentKeys(o)
   let acceptMessageResponse = await p.offerImpl(o.dst, contentKeys)
 
@@ -716,7 +723,7 @@ proc offer*(p: PortalProtocol, dst: Node, content: seq[ContentInfo]):
     Future[PortalResult[void]] {.async.} =
   if len(content) > contentKeysLimit:
     return err("Cannot offer more than 64 content items")
-  
+
   let contentList = List[ContentInfo, contentKeysLimit].init(content)
   let req = OfferRequest(dst: dst, kind: Direct, contentList: contentList)
   let res = await p.offer(req)
@@ -759,24 +766,26 @@ proc processContent(
     {.gcsafe, raises: [Defect].} =
   let p = getUserData[PortalProtocol](stream)
 
-  # TODO: validate content
-  # - check amount of content items according to ContentKeysList
-  # - History Network specific: each content item, if header, check hash:
-  # this part of thevalidation will be specific per network & type and should
-  # be thus be custom per network
+  # TODO:
+  # - Implement a way to discern different content items (e.g. length prefixed)
+  # - Check amount of content items according to ContentKeysList
+  # - The above could also live in `PortalStream`
 
   # TODO: for now we only consider 1 item being offered
   if contentKeys.len() == 1:
     let contentKey = contentKeys[0]
-    let contentIdOpt = p.toContentId(contentKey)
-    if contentIdOpt.isNone():
-      return
+    if p.validateContent(content, contentKey):
+      let contentIdOpt = p.toContentId(contentKey)
+      if contentIdOpt.isNone():
+        return
 
-    let contentId = contentIdOpt.get()
-    # Store content, should we recheck radius?
-    p.contentDB.put(contentId, content)
+      let contentId = contentIdOpt.get()
+      # Store content, should we recheck radius?
+      p.contentDB.put(contentId, content)
 
-    asyncSpawn neighborhoodGossip(p, contentKeys)
+      asyncSpawn neighborhoodGossip(p, contentKeys)
+    else:
+      error "Received invalid content", contentKey
 
 proc lookupWorker(
     p: PortalProtocol, dst: Node, target: NodeId): Future[seq[Node]] {.async.} =
@@ -854,7 +863,7 @@ proc lookup*(p: PortalProtocol, target: NodeId): Future[seq[Node]] {.async.} =
 
 proc triggerPoke*(
     p: PortalProtocol,
-    nodes: seq[Node], 
+    nodes: seq[Node],
     contentKey: ByteList,
     content: seq[byte]) =
   ## Triggers asynchronous offer-accept interaction to provided nodes.
@@ -871,7 +880,7 @@ proc triggerPoke*(
         raiseAssert(e.msg)
     else:
       # offer queue full, do not start more offer offer-accept interactions
-      return 
+      return
 
 # TODO ContentLookup and Lookup look almost exactly the same, also lookups in other
 # networks will probably be very similar. Extract lookup function to separate module
@@ -947,7 +956,7 @@ proc contentLookup*(p: PortalProtocol, target: ByteList, targetId: UInt256):
 
           if closestNodes.len > BUCKET_SIZE:
             closestNodes.del(closestNodes.high())
-      
+
       of Content:
         # cancel any pending queries as we have find the content
         for f in pendingQueries:
diff --git a/fluffy/network/wire/portal_stream.nim b/fluffy/network/wire/portal_stream.nim
index ff00f7bab..02f0a31bd 100644
--- a/fluffy/network/wire/portal_stream.nim
+++ b/fluffy/network/wire/portal_stream.nim
@@ -75,12 +75,25 @@ type
     udata: pointer
     contentHandler: ContentHandlerCallback
 
+proc pruneAllowedConnections(stream: PortalStream) =
+  # Prune requests and offers that didn't receive a connection request
+  # before `connectionTimeout`.
+  let now = Moment.now()
+  stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
+    x.timeout > now)
+  stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
+    x.timeout > now)
+
 proc getUserData*[T](stream: PortalStream): T =
   ## Obtain user data stored in ``stream`` object.
   cast[T](stream.udata)
 
 proc addContentOffer*(
     stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList): Bytes2 =
+  stream.pruneAllowedConnections()
+
+  # TODO: Should we check if `NodeId` & `connectionId` combo already exists?
+  # What happens if we get duplicates?
   var connectionId: Bytes2
   brHmacDrbgGenerate(stream.rng[], connectionId)
 
@@ -97,6 +110,10 @@ proc addContentOffer*(
 
 proc addContentRequest*(
     stream: PortalStream, nodeId: NodeId, content: seq[byte]): Bytes2 =
+  stream.pruneAllowedConnections()
+
+  # TODO: Should we check if `NodeId` & `connectionId` combo already exists?
+  # What happens if we get duplicates?
   var connectionId: Bytes2
   brHmacDrbgGenerate(stream.rng[], connectionId)
 
@@ -194,15 +211,6 @@ proc new*(
 func setTransport*(stream: PortalStream, transport: UtpDiscv5Protocol) =
   stream.transport = transport
 
-proc pruneAllowedConnections(stream: PortalStream) =
-  # Prune requests and offers that didn't receive a connection request
-  # before `connectionTimeout`.
-  let now = Moment.now()
-  stream.contentRequests.keepIf(proc(x: ContentRequest): bool =
-    x.timeout > now)
-  stream.contentOffers.keepIf(proc(x: ContentOffer): bool =
-    x.timeout > now)
-
 # TODO: I think I'd like it more if we weren't to capture the stream.
 proc registerIncomingSocketCallback*(
     streams: seq[PortalStream]): AcceptConnectionCallback[NodeAddress] =
diff --git a/fluffy/tests/test_portal_wire_protocol.nim b/fluffy/tests/test_portal_wire_protocol.nim
index cbe8f69d2..a44983032 100644
--- a/fluffy/tests/test_portal_wire_protocol.nim
+++ b/fluffy/tests/test_portal_wire_protocol.nim
@@ -39,6 +39,9 @@ proc testHandlerSha256(contentKey: ByteList): Option[ContentId] =
   let idHash = sha256.digest(contentKey.asSeq())
   some(readUintBE[256](idHash.data))
 
+proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
+  true
+
 proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
   let
     node1 = initDiscoveryNode(
@@ -49,8 +52,10 @@ proc defaultTestCase(rng: ref BrHmacDrbgContext): Default2NodeTest =
     db1 = ContentDB.new("", inMemory = true)
     db2 = ContentDB.new("", inMemory = true)
 
-    proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
-    proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
+    proto1 =
+      PortalProtocol.new(node1, protocolId, db1, testHandler, validateContent)
+    proto2 =
+      PortalProtocol.new(node2, protocolId, db2, testHandler, validateContent)
 
   Default2NodeTest(node1: node1, node2: node2, proto1: proto1, proto2: proto2)
 
@@ -206,9 +211,12 @@ procSuite "Portal Wire Protocol Tests":
       db2 = ContentDB.new("", inMemory = true)
      db3 = ContentDB.new("", inMemory = true)
 
-      proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
-      proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler)
-      proto3 = PortalProtocol.new(node3, protocolId, db3, testHandler)
+      proto1 = PortalProtocol.new(
+        node1, protocolId, db1, testHandler, validateContent)
+      proto2 = PortalProtocol.new(
+        node2, protocolId, db2, testHandler, validateContent)
+      proto3 = PortalProtocol.new(
+        node3, protocolId, db3, testHandler, validateContent)
 
     # Node1 knows about Node2, and Node2 knows about Node3 which hold all content
     check proto1.addNode(node2.localNode) == Added
@@ -239,14 +247,17 @@ procSuite "Portal Wire Protocol Tests":
      db2 = ContentDB.new("", inMemory = true)
      db3 = ContentDB.new("", inMemory = true)
 
-      proto1 = PortalProtocol.new(node1, protocolId, db1, testHandlerSha256)
-      proto2 = PortalProtocol.new(node2, protocolId, db2, testHandlerSha256)
-      proto3 = PortalProtocol.new(node3, protocolId, db3, testHandlerSha256)
+      proto1 = PortalProtocol.new(
+        node1, protocolId, db1, testHandlerSha256, validateContent)
+      proto2 = PortalProtocol.new(
+        node2, protocolId, db2, testHandlerSha256, validateContent)
+      proto3 = PortalProtocol.new(
+        node3, protocolId, db3, testHandlerSha256, validateContent)
 
       content = @[byte 1, 2]
       contentList = List[byte, 2048].init(content)
       contentId = readUintBE[256](sha256.digest(content).data)
-    
+
     # Only node3 have content
     db3.put(contentId, content)
 
@@ -283,8 +294,10 @@ procSuite "Portal Wire Protocol Tests":
      db1 = ContentDB.new("", inMemory = true)
      db2 = ContentDB.new("", inMemory = true)
 
-      proto1 = PortalProtocol.new(node1, protocolId, db1, testHandler)
-      proto2 = PortalProtocol.new(node2, protocolId, db2, testHandler,
+      proto1 = PortalProtocol.new(
+        node1, protocolId, db1, testHandler, validateContent)
+      proto2 = PortalProtocol.new(
+        node2, protocolId, db2, testHandler, validateContent,
         bootstrapRecords = [node1.localNode.record])
 
     proto1.start()
@@ -307,7 +320,7 @@ procSuite "Portal Wire Protocol Tests":
       db = ContentDB.new("", inMemory = true)
       # No portal protocol for node1, hence an invalid bootstrap node
       proto2 = PortalProtocol.new(node2, protocolId, db, testHandler,
-        bootstrapRecords = [node1.localNode.record])
+        validateContent, bootstrapRecords = [node1.localNode.record])
 
     # seedTable to add node1 to the routing table
     proto2.seedTable()
diff --git a/fluffy/tools/portalcli.nim b/fluffy/tools/portalcli.nim
index 61a8c9bbb..f46313cde 100644
--- a/fluffy/tools/portalcli.nim
+++ b/fluffy/tools/portalcli.nim
@@ -187,6 +187,9 @@ proc testHandler(contentKey: ByteList): Option[ContentId] =
   let idHash = sha256.digest("test")
   some(readUintBE[256](idHash.data))
 
+proc validateContent(content: openArray[byte], contentKey: ByteList): bool =
+  true
+
 proc run(config: PortalCliConf) =
   let
     rng = newRng()
@@ -212,7 +215,8 @@ proc run(config: PortalCliConf) =
   let
     db = ContentDB.new("", inMemory = true)
 
-    portal = PortalProtocol.new(d, config.protocolId, db, testHandler,
+    portal = PortalProtocol.new(d, config.protocolId, db,
+      testHandler, validateContent,
      bootstrapRecords = bootstrapRecords)
     socketConfig = SocketConfig.init(
      incomingSocketReceiveTimeout = none(Duration))
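For reference, the patch threads the new `ContentValidationHandler` through `PortalProtocol.new` right after the `ToContentIdHandler`, and `processContent` now only stores and gossips an offered item when that callback returns true. A minimal sketch of supplying a custom handler, using only the signatures introduced above; `rejectEmptyContent` is a hypothetical example handler, and `baseProtocol`, `protocolId`, `contentDB`, `toContentIdHandler` and `bootstrapRecords` stand in for whatever the caller has already set up (as in the test and portalcli stubs):

proc rejectEmptyContent(content: openArray[byte], contentKey: ByteList): bool =
  # Illustrative only: refuse zero-length payloads, accept everything else.
  content.len > 0

let portal = PortalProtocol.new(
  baseProtocol, protocolId, contentDB, toContentIdHandler, rejectEmptyContent,
  bootstrapRecords = bootstrapRecords)

A network that can actually check the payload would decode the content key and verify the content against it, as the history network's `validateContent` above does for block headers.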