Revert commit 6142183 and partial of b446d2a (#2898)

An assertion is being hit due to the addition of an iterator
that deletes items from the sequence while iterating over it.
Previously, the keepIf helper was used, which performs this
similar work with different (safe) code.
This commit is contained in:
Kim De Mey 2024-12-02 20:09:58 +07:00 committed by GitHub
parent dd888deadb
commit 0f18de61dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 109 deletions

View File

@ -408,7 +408,7 @@ proc handleFindContent(
) )
# Check first if content is in range, as this is a cheaper operation # Check first if content is in range, as this is a cheaper operation
if p.inRange(contentId) and p.stream.canAddPendingTransfer(srcId, contentId): if p.inRange(contentId):
let contentResult = p.dbGet(fc.contentKey, contentId) let contentResult = p.dbGet(fc.contentKey, contentId)
if contentResult.isOk(): if contentResult.isOk():
let content = contentResult.get() let content = contentResult.get()
@ -419,8 +419,7 @@ proc handleFindContent(
) )
) )
else: else:
p.stream.addPendingTransfer(srcId, contentId) let connectionId = p.stream.addContentRequest(srcId, content)
let connectionId = p.stream.addContentRequest(srcId, contentId, content)
return encodeMessage( return encodeMessage(
ContentMessage( ContentMessage(
@ -449,10 +448,8 @@ proc handleOffer(p: PortalProtocol, o: OfferMessage, srcId: NodeId): seq[byte] =
) )
) )
var var contentKeysBitList = ContentKeysBitList.init(o.contentKeys.len)
contentKeysBitList = ContentKeysBitList.init(o.contentKeys.len) var contentKeys = ContentKeysList.init(@[])
contentKeys = ContentKeysList.init(@[])
contentIds = newSeq[ContentId]()
# TODO: Do we need some protection against a peer offering lots (64x) of # TODO: Do we need some protection against a peer offering lots (64x) of
# content that fits our Radius but is actually bogus? # content that fits our Radius but is actually bogus?
# Additional TODO, but more of a specification clarification: What if we don't # Additional TODO, but more of a specification clarification: What if we don't
@ -468,19 +465,17 @@ proc handleOffer(p: PortalProtocol, o: OfferMessage, srcId: NodeId): seq[byte] =
int64(logDistance), labelValues = [$p.protocolId] int64(logDistance), labelValues = [$p.protocolId]
) )
if p.inRange(contentId) and p.stream.canAddPendingTransfer(srcId, contentId) and if p.inRange(contentId):
not p.dbContains(contentKey, contentId): if not p.dbContains(contentKey, contentId):
p.stream.addPendingTransfer(srcId, contentId)
contentKeysBitList.setBit(i) contentKeysBitList.setBit(i)
discard contentKeys.add(contentKey) discard contentKeys.add(contentKey)
contentIds.add(contentId)
else: else:
# Return empty response when content key validation fails # Return empty response when content key validation fails
return @[] return @[]
let connectionId = let connectionId =
if contentKeysBitList.countOnes() != 0: if contentKeysBitList.countOnes() != 0:
p.stream.addContentOffer(srcId, contentKeys, contentIds) p.stream.addContentOffer(srcId, contentKeys)
else: else:
# When the node does not accept any of the content offered, reply with an # When the node does not accept any of the content offered, reply with an
# all zeroes bitlist and connectionId. # all zeroes bitlist and connectionId.

View File

@ -8,7 +8,7 @@
{.push raises: [].} {.push raises: [].}
import import
std/[sequtils, sets], std/sequtils,
chronos, chronos,
stew/[byteutils, leb128, endians2], stew/[byteutils, leb128, endians2],
chronicles, chronicles,
@ -35,20 +35,17 @@ const
talkReqOverhead = getTalkReqOverhead(utpProtocolId) talkReqOverhead = getTalkReqOverhead(utpProtocolId)
utpHeaderOverhead = 20 utpHeaderOverhead = 20
maxUtpPayloadSize = maxDiscv5PacketSize - talkReqOverhead - utpHeaderOverhead maxUtpPayloadSize = maxDiscv5PacketSize - talkReqOverhead - utpHeaderOverhead
maxPendingTransfersPerPeer = 128
type type
ContentRequest = object ContentRequest = object
connectionId: uint16 connectionId: uint16
nodeId: NodeId nodeId: NodeId
contentId: ContentId
content: seq[byte] content: seq[byte]
timeout: Moment timeout: Moment
ContentOffer = object ContentOffer = object
connectionId: uint16 connectionId: uint16
nodeId: NodeId nodeId: NodeId
contentIds: seq[ContentId]
contentKeys: ContentKeysList contentKeys: ContentKeysList
timeout: Moment timeout: Moment
@ -72,7 +69,6 @@ type
connectionTimeout: Duration connectionTimeout: Duration
contentReadTimeout*: Duration contentReadTimeout*: Duration
rng: ref HmacDrbgContext rng: ref HmacDrbgContext
pendingTransfers: TableRef[NodeId, HashSet[ContentId]]
contentQueue*: AsyncQueue[(Opt[NodeId], ContentKeysList, seq[seq[byte]])] contentQueue*: AsyncQueue[(Opt[NodeId], ContentKeysList, seq[seq[byte]])]
StreamManager* = ref object StreamManager* = ref object
@ -80,93 +76,21 @@ type
streams: seq[PortalStream] streams: seq[PortalStream]
rng: ref HmacDrbgContext rng: ref HmacDrbgContext
proc canAddPendingTransfer(
transfers: TableRef[NodeId, HashSet[ContentId]],
nodeId: NodeId,
contentId: ContentId,
limit: int,
): bool =
if not transfers.contains(nodeId):
return true
try:
let contentIds = transfers[nodeId]
if (contentIds.len() < limit) and not contentIds.contains(contentId):
return true
else:
debug "Pending transfer limit reached for peer", nodeId, contentId
return false
except KeyError as e:
raiseAssert(e.msg)
proc addPendingTransfer(
transfers: TableRef[NodeId, HashSet[ContentId]],
nodeId: NodeId,
contentId: ContentId,
) =
if transfers.contains(nodeId):
try:
transfers[nodeId].incl(contentId)
except KeyError as e:
raiseAssert(e.msg)
else:
var contentIds = initHashSet[ContentId]()
contentIds.incl(contentId)
transfers[nodeId] = contentIds
proc removePendingTransfer(
transfers: TableRef[NodeId, HashSet[ContentId]],
nodeId: NodeId,
contentId: ContentId,
) =
doAssert transfers.contains(nodeId)
try:
transfers[nodeId].excl(contentId)
if transfers[nodeId].len() == 0:
transfers.del(nodeId)
except KeyError as e:
raiseAssert(e.msg)
template canAddPendingTransfer*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId
): bool =
stream.pendingTransfers.canAddPendingTransfer(
srcId, contentId, maxPendingTransfersPerPeer
)
template addPendingTransfer*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId
) =
addPendingTransfer(stream.pendingTransfers, nodeId, contentId)
template removePendingTransfer*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId
) =
removePendingTransfer(stream.pendingTransfers, nodeId, contentId)
proc pruneAllowedConnections(stream: PortalStream) = proc pruneAllowedConnections(stream: PortalStream) =
# Prune requests and offers that didn't receive a connection request # Prune requests and offers that didn't receive a connection request
# before `connectionTimeout`. # before `connectionTimeout`.
let now = Moment.now() let now = Moment.now()
stream.contentRequests.keepIf(
for i, request in stream.contentRequests: proc(x: ContentRequest): bool =
if request.timeout <= now: x.timeout > now
stream.removePendingTransfer(request.nodeId, request.contentId) )
stream.contentRequests.del(i) stream.contentOffers.keepIf(
proc(x: ContentOffer): bool =
for i, offer in stream.contentOffers: x.timeout > now
if offer.timeout <= now: )
for contentId in offer.contentIds:
stream.removePendingTransfer(offer.nodeId, contentId)
stream.contentOffers.del(i)
proc addContentOffer*( proc addContentOffer*(
stream: PortalStream, stream: PortalStream, nodeId: NodeId, contentKeys: ContentKeysList
nodeId: NodeId,
contentKeys: ContentKeysList,
contentIds: seq[ContentId],
): Bytes2 = ): Bytes2 =
stream.pruneAllowedConnections() stream.pruneAllowedConnections()
@ -183,7 +107,6 @@ proc addContentOffer*(
let contentOffer = ContentOffer( let contentOffer = ContentOffer(
connectionId: id, connectionId: id,
nodeId: nodeId, nodeId: nodeId,
contentIds: contentIds,
contentKeys: contentKeys, contentKeys: contentKeys,
timeout: Moment.now() + stream.connectionTimeout, timeout: Moment.now() + stream.connectionTimeout,
) )
@ -192,7 +115,7 @@ proc addContentOffer*(
return connectionId return connectionId
proc addContentRequest*( proc addContentRequest*(
stream: PortalStream, nodeId: NodeId, contentId: ContentId, content: seq[byte] stream: PortalStream, nodeId: NodeId, content: seq[byte]
): Bytes2 = ): Bytes2 =
stream.pruneAllowedConnections() stream.pruneAllowedConnections()
@ -206,7 +129,6 @@ proc addContentRequest*(
let contentRequest = ContentRequest( let contentRequest = ContentRequest(
connectionId: id, connectionId: id,
nodeId: nodeId, nodeId: nodeId,
contentId: contentId,
content: content, content: content,
timeout: Moment.now() + stream.connectionTimeout, timeout: Moment.now() + stream.connectionTimeout,
) )
@ -363,7 +285,6 @@ proc new(
transport: transport, transport: transport,
connectionTimeout: connectionTimeout, connectionTimeout: connectionTimeout,
contentReadTimeout: contentReadTimeout, contentReadTimeout: contentReadTimeout,
pendingTransfers: newTable[NodeId, HashSet[ContentId]](),
contentQueue: contentQueue, contentQueue: contentQueue,
rng: rng, rng: rng,
) )
@ -396,8 +317,6 @@ proc handleIncomingConnection(
if request.connectionId == socket.connectionId and if request.connectionId == socket.connectionId and
request.nodeId == socket.remoteAddress.nodeId: request.nodeId == socket.remoteAddress.nodeId:
let fut = socket.writeContentRequest(stream, request) let fut = socket.writeContentRequest(stream, request)
stream.removePendingTransfer(request.nodeId, request.contentId)
stream.contentRequests.del(i) stream.contentRequests.del(i)
return noCancel(fut) return noCancel(fut)
@ -405,9 +324,6 @@ proc handleIncomingConnection(
if offer.connectionId == socket.connectionId and if offer.connectionId == socket.connectionId and
offer.nodeId == socket.remoteAddress.nodeId: offer.nodeId == socket.remoteAddress.nodeId:
let fut = socket.readContentOffer(stream, offer) let fut = socket.readContentOffer(stream, offer)
for contentId in offer.contentIds:
stream.removePendingTransfer(offer.nodeId, contentId)
stream.contentOffers.del(i) stream.contentOffers.del(i)
return noCancel(fut) return noCancel(fut)