Fluffy: Limit concurrent offers that can be received from each peer (#2885)

* Limit offer transfers per peer.

* Remove pending transfers in prune.

* Limit content lookups.

* Improve performance of canAddPendingTransfer and addPendingTransfer.
This commit is contained in:
bhartnett 2024-11-29 11:31:46 +08:00 committed by GitHub
parent c0199e8944
commit 6142183d2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 20 deletions

View File

@ -408,7 +408,7 @@ proc handleFindContent(
)
# Check first if content is in range, as this is a cheaper operation
if p.inRange(contentId):
if p.inRange(contentId) and p.stream.canAddPendingTransfer(srcId, contentId):
let contentResult = p.dbGet(fc.contentKey, contentId)
if contentResult.isOk():
let content = contentResult.get()
@ -419,7 +419,8 @@ proc handleFindContent(
)
)
else:
let connectionId = p.stream.addContentRequest(srcId, content)
p.stream.addPendingTransfer(srcId, contentId)
let connectionId = p.stream.addContentRequest(srcId, contentId, content)
return encodeMessage(
ContentMessage(
@ -448,8 +449,10 @@ proc handleOffer(p: PortalProtocol, o: OfferMessage, srcId: NodeId): seq[byte] =
)
)
var contentKeysBitList = ContentKeysBitList.init(o.contentKeys.len)
var contentKeys = ContentKeysList.init(@[])
var
contentKeysBitList = ContentKeysBitList.init(o.contentKeys.len)
contentKeys = ContentKeysList.init(@[])
contentIds = newSeq[ContentId]()
# TODO: Do we need some protection against a peer offering lots (64x) of
# content that fits our Radius but is actually bogus?
# Additional TODO, but more of a specification clarification: What if we don't
@ -465,17 +468,19 @@ proc handleOffer(p: PortalProtocol, o: OfferMessage, srcId: NodeId): seq[byte] =
int64(logDistance), labelValues = [$p.protocolId]
)
if p.inRange(contentId):
if not p.dbContains(contentKey, contentId):
if p.inRange(contentId) and p.stream.canAddPendingTransfer(srcId, contentId) and
not p.dbContains(contentKey, contentId):
p.stream.addPendingTransfer(srcId, contentId)
contentKeysBitList.setBit(i)
discard contentKeys.add(contentKey)
contentIds.add(contentId)
else:
# Return empty response when content key validation fails
return @[]
let connectionId =
if contentKeysBitList.countOnes() != 0:
p.stream.addContentOffer(srcId, contentKeys)
p.stream.addContentOffer(srcId, contentKeys, contentIds)
else:
# When the node does not accept any of the content offered, reply with an
# all zeroes bitlist and connectionId.

View File

@ -8,7 +8,7 @@
{.push raises: [].}
import
std/sequtils,
std/[sequtils, sets],
chronos,
stew/[byteutils, leb128, endians2],
chronicles,
@ -35,17 +35,20 @@ const
talkReqOverhead = getTalkReqOverhead(utpProtocolId)
utpHeaderOverhead = 20
maxUtpPayloadSize = maxDiscv5PacketSize - talkReqOverhead - utpHeaderOverhead
maxPendingTransfersPerPeer = 128
type
ContentRequest = object
connectionId: uint16
nodeId: NodeId
contentId: ContentId
content: seq[byte]
timeout: Moment
ContentOffer = object
connectionId: uint16
nodeId: NodeId
contentIds: seq[ContentId]
contentKeys: ContentKeysList
timeout: Moment
@ -69,6 +72,7 @@ type
connectionTimeout: Duration
contentReadTimeout*: Duration
rng: ref HmacDrbgContext
pendingTransfers: TableRef[NodeId, HashSet[ContentId]]
contentQueue*: AsyncQueue[(Opt[NodeId], ContentKeysList, seq[seq[byte]])]
StreamManager* = ref object
@ -76,21 +80,89 @@ type
streams: seq[PortalStream]
rng: ref HmacDrbgContext
proc canAddPendingTransfer(
transfers: TableRef[NodeId, HashSet[ContentId]],
nodeId: NodeId,
contentId: ContentId,
limit: int,
): bool =
if not transfers.contains(nodeId):
return true
try:
let contentIds = transfers[nodeId]
(contentIds.len() < limit) and not contentIds.contains(contentId)
except KeyError as e:
raiseAssert(e.msg)
proc addPendingTransfer(
transfers: TableRef[NodeId, HashSet[ContentId]],
nodeId: NodeId,
contentId: ContentId,
) =
if transfers.contains(nodeId):
try:
transfers[nodeId].incl(contentId)
except KeyError as e:
raiseAssert(e.msg)
else:
var contentIds = initHashSet[ContentId]()
contentIds.incl(contentId)
transfers[nodeId] = contentIds
proc removePendingTransfer(
transfers: TableRef[NodeId, HashSet[ContentId]],
nodeId: NodeId,
contentId: ContentId,
) =
doAssert transfers.contains(nodeId)
try:
transfers[nodeId].excl(contentId)
if transfers[nodeId].len() == 0:
transfers.del(nodeId)
except KeyError as e:
raiseAssert(e.msg)
template canAddPendingTransfer*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId
): bool =
stream.pendingTransfers.canAddPendingTransfer(
srcId, contentId, maxPendingTransfersPerPeer
)
template addPendingTransfer*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId
) =
addPendingTransfer(stream.pendingTransfers, nodeId, contentId)
template removePendingTransfer*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId
) =
removePendingTransfer(stream.pendingTransfers, nodeId, contentId)
proc pruneAllowedConnections(stream: PortalStream) =
# Prune requests and offers that didn't receive a connection request
# before `connectionTimeout`.
let now = Moment.now()
stream.contentRequests.keepIf(
proc(x: ContentRequest): bool =
x.timeout > now
)
stream.contentOffers.keepIf(
proc(x: ContentOffer): bool =
x.timeout > now
)
for i, request in stream.contentRequests:
if request.timeout <= now:
stream.removePendingTransfer(request.nodeId, request.contentId)
stream.contentRequests.del(i)
for i, offer in stream.contentOffers:
if offer.timeout <= now:
for contentId in offer.contentIds:
stream.removePendingTransfer(offer.nodeId, contentId)
stream.contentOffers.del(i)
proc addContentOffer*(
stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList
stream: PortalStream,
nodeId: NodeId,
contentKeys: ContentKeysList,
contentIds: seq[ContentId],
): Bytes2 =
stream.pruneAllowedConnections()
@ -107,6 +179,7 @@ proc addContentOffer*(
let contentOffer = ContentOffer(
connectionId: id,
nodeId: nodeId,
contentIds: contentIds,
contentKeys: contentKeys,
timeout: Moment.now() + stream.connectionTimeout,
)
@ -115,7 +188,7 @@ proc addContentOffer*(
return connectionId
proc addContentRequest*(
stream: PortalStream, nodeId: NodeId, content: seq[byte]
stream: PortalStream, nodeId: NodeId, contentId: ContentId, content: seq[byte]
): Bytes2 =
stream.pruneAllowedConnections()
@ -129,6 +202,7 @@ proc addContentRequest*(
let contentRequest = ContentRequest(
connectionId: id,
nodeId: nodeId,
contentId: contentId,
content: content,
timeout: Moment.now() + stream.connectionTimeout,
)
@ -285,6 +359,7 @@ proc new(
transport: transport,
connectionTimeout: connectionTimeout,
contentReadTimeout: contentReadTimeout,
pendingTransfers: newTable[NodeId, HashSet[ContentId]](),
contentQueue: contentQueue,
rng: rng,
)
@ -317,6 +392,8 @@ proc handleIncomingConnection(
if request.connectionId == socket.connectionId and
request.nodeId == socket.remoteAddress.nodeId:
let fut = socket.writeContentRequest(stream, request)
stream.removePendingTransfer(request.nodeId, request.contentId)
stream.contentRequests.del(i)
return noCancel(fut)
@ -324,6 +401,9 @@ proc handleIncomingConnection(
if offer.connectionId == socket.connectionId and
offer.nodeId == socket.remoteAddress.nodeId:
let fut = socket.readContentOffer(stream, offer)
for contentId in offer.contentIds:
stream.removePendingTransfer(offer.nodeId, contentId)
stream.contentOffers.del(i)
return noCancel(fut)