diff --git a/fluffy/network/wire/portal_protocol.nim b/fluffy/network/wire/portal_protocol.nim index c19b4660e..5c0042626 100644 --- a/fluffy/network/wire/portal_protocol.nim +++ b/fluffy/network/wire/portal_protocol.nim @@ -408,7 +408,7 @@ proc handleFindContent( ) # Check first if content is in range, as this is a cheaper operation - if p.inRange(contentId) and p.stream.canAddPendingTransfer(srcId, contentId): + if p.inRange(contentId): let contentResult = p.dbGet(fc.contentKey, contentId) if contentResult.isOk(): let content = contentResult.get() @@ -419,8 +419,7 @@ proc handleFindContent( ) ) else: - p.stream.addPendingTransfer(srcId, contentId) - let connectionId = p.stream.addContentRequest(srcId, contentId, content) + let connectionId = p.stream.addContentRequest(srcId, content) return encodeMessage( ContentMessage( @@ -449,10 +448,8 @@ proc handleOffer(p: PortalProtocol, o: OfferMessage, srcId: NodeId): seq[byte] = ) ) - var - contentKeysBitList = ContentKeysBitList.init(o.contentKeys.len) - contentKeys = ContentKeysList.init(@[]) - contentIds = newSeq[ContentId]() + var contentKeysBitList = ContentKeysBitList.init(o.contentKeys.len) + var contentKeys = ContentKeysList.init(@[]) # TODO: Do we need some protection against a peer offering lots (64x) of # content that fits our Radius but is actually bogus? # Additional TODO, but more of a specification clarification: What if we don't @@ -468,19 +465,17 @@ proc handleOffer(p: PortalProtocol, o: OfferMessage, srcId: NodeId): seq[byte] = int64(logDistance), labelValues = [$p.protocolId] ) - if p.inRange(contentId) and p.stream.canAddPendingTransfer(srcId, contentId) and - not p.dbContains(contentKey, contentId): - p.stream.addPendingTransfer(srcId, contentId) - contentKeysBitList.setBit(i) - discard contentKeys.add(contentKey) - contentIds.add(contentId) + if p.inRange(contentId): + if not p.dbContains(contentKey, contentId): + contentKeysBitList.setBit(i) + discard contentKeys.add(contentKey) else: # Return empty response when content key validation fails return @[] let connectionId = if contentKeysBitList.countOnes() != 0: - p.stream.addContentOffer(srcId, contentKeys, contentIds) + p.stream.addContentOffer(srcId, contentKeys) else: # When the node does not accept any of the content offered, reply with an # all zeroes bitlist and connectionId. diff --git a/fluffy/network/wire/portal_stream.nim b/fluffy/network/wire/portal_stream.nim index 941ed65c6..13a68158a 100644 --- a/fluffy/network/wire/portal_stream.nim +++ b/fluffy/network/wire/portal_stream.nim @@ -8,7 +8,7 @@ {.push raises: [].} import - std/[sequtils, sets], + std/sequtils, chronos, stew/[byteutils, leb128, endians2], chronicles, @@ -35,20 +35,17 @@ const talkReqOverhead = getTalkReqOverhead(utpProtocolId) utpHeaderOverhead = 20 maxUtpPayloadSize = maxDiscv5PacketSize - talkReqOverhead - utpHeaderOverhead - maxPendingTransfersPerPeer = 128 type ContentRequest = object connectionId: uint16 nodeId: NodeId - contentId: ContentId content: seq[byte] timeout: Moment ContentOffer = object connectionId: uint16 nodeId: NodeId - contentIds: seq[ContentId] contentKeys: ContentKeysList timeout: Moment @@ -72,7 +69,6 @@ type connectionTimeout: Duration contentReadTimeout*: Duration rng: ref HmacDrbgContext - pendingTransfers: TableRef[NodeId, HashSet[ContentId]] contentQueue*: AsyncQueue[(Opt[NodeId], ContentKeysList, seq[seq[byte]])] StreamManager* = ref object @@ -80,93 +76,21 @@ type streams: seq[PortalStream] rng: ref HmacDrbgContext -proc canAddPendingTransfer( - transfers: TableRef[NodeId, HashSet[ContentId]], - nodeId: NodeId, - contentId: ContentId, - limit: int, -): bool = - if not transfers.contains(nodeId): - return true - - try: - let contentIds = transfers[nodeId] - if (contentIds.len() < limit) and not contentIds.contains(contentId): - return true - else: - debug "Pending transfer limit reached for peer", nodeId, contentId - return false - except KeyError as e: - raiseAssert(e.msg) - -proc addPendingTransfer( - transfers: TableRef[NodeId, HashSet[ContentId]], - nodeId: NodeId, - contentId: ContentId, -) = - if transfers.contains(nodeId): - try: - transfers[nodeId].incl(contentId) - except KeyError as e: - raiseAssert(e.msg) - else: - var contentIds = initHashSet[ContentId]() - contentIds.incl(contentId) - transfers[nodeId] = contentIds - -proc removePendingTransfer( - transfers: TableRef[NodeId, HashSet[ContentId]], - nodeId: NodeId, - contentId: ContentId, -) = - doAssert transfers.contains(nodeId) - - try: - transfers[nodeId].excl(contentId) - - if transfers[nodeId].len() == 0: - transfers.del(nodeId) - except KeyError as e: - raiseAssert(e.msg) - -template canAddPendingTransfer*( - stream: PortalStream, nodeId: NodeId, contentId: ContentId -): bool = - stream.pendingTransfers.canAddPendingTransfer( - srcId, contentId, maxPendingTransfersPerPeer - ) - -template addPendingTransfer*( - stream: PortalStream, nodeId: NodeId, contentId: ContentId -) = - addPendingTransfer(stream.pendingTransfers, nodeId, contentId) - -template removePendingTransfer*( - stream: PortalStream, nodeId: NodeId, contentId: ContentId -) = - removePendingTransfer(stream.pendingTransfers, nodeId, contentId) - proc pruneAllowedConnections(stream: PortalStream) = # Prune requests and offers that didn't receive a connection request # before `connectionTimeout`. let now = Moment.now() - - for i, request in stream.contentRequests: - if request.timeout <= now: - stream.removePendingTransfer(request.nodeId, request.contentId) - stream.contentRequests.del(i) - - for i, offer in stream.contentOffers: - if offer.timeout <= now: - for contentId in offer.contentIds: - stream.removePendingTransfer(offer.nodeId, contentId) - stream.contentOffers.del(i) + stream.contentRequests.keepIf( + proc(x: ContentRequest): bool = + x.timeout > now + ) + stream.contentOffers.keepIf( + proc(x: ContentOffer): bool = + x.timeout > now + ) proc addContentOffer*( - stream: PortalStream, - nodeId: NodeId, - contentKeys: ContentKeysList, - contentIds: seq[ContentId], + stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList ): Bytes2 = stream.pruneAllowedConnections() @@ -183,7 +107,6 @@ proc addContentOffer*( let contentOffer = ContentOffer( connectionId: id, nodeId: nodeId, - contentIds: contentIds, contentKeys: contentKeys, timeout: Moment.now() + stream.connectionTimeout, ) @@ -192,7 +115,7 @@ proc addContentOffer*( return connectionId proc addContentRequest*( - stream: PortalStream, nodeId: NodeId, contentId: ContentId, content: seq[byte] + stream: PortalStream, nodeId: NodeId, content: seq[byte] ): Bytes2 = stream.pruneAllowedConnections() @@ -206,7 +129,6 @@ proc addContentRequest*( let contentRequest = ContentRequest( connectionId: id, nodeId: nodeId, - contentId: contentId, content: content, timeout: Moment.now() + stream.connectionTimeout, ) @@ -363,7 +285,6 @@ proc new( transport: transport, connectionTimeout: connectionTimeout, contentReadTimeout: contentReadTimeout, - pendingTransfers: newTable[NodeId, HashSet[ContentId]](), contentQueue: contentQueue, rng: rng, ) @@ -396,8 +317,6 @@ proc handleIncomingConnection( if request.connectionId == socket.connectionId and request.nodeId == socket.remoteAddress.nodeId: let fut = socket.writeContentRequest(stream, request) - - stream.removePendingTransfer(request.nodeId, request.contentId) stream.contentRequests.del(i) return noCancel(fut) @@ -405,9 +324,6 @@ proc handleIncomingConnection( if offer.connectionId == socket.connectionId and offer.nodeId == socket.remoteAddress.nodeId: let fut = socket.readContentOffer(stream, offer) - - for contentId in offer.contentIds: - stream.removePendingTransfer(offer.nodeId, contentId) stream.contentOffers.del(i) return noCancel(fut)