From 0c63ce4e9bd219092b7496d8b6b74b28b238ae3e Mon Sep 17 00:00:00 2001 From: Simon-Pierre Vivier Date: Thu, 24 Apr 2025 09:07:21 -0400 Subject: [PATCH] feat: refactor waku sync DOS protection (#3391) --- tests/waku_store_sync/sync_utils.nim | 6 ++--- tests/waku_store_sync/test_protocol.nim | 31 ++++++---------------- waku/node/waku_node.nim | 2 +- waku/waku_store_sync/reconciliation.nim | 18 +++++++------ waku/waku_store_sync/transfer.nim | 35 +++++++++---------------- 5 files changed, 34 insertions(+), 58 deletions(-) diff --git a/tests/waku_store_sync/sync_utils.nim b/tests/waku_store_sync/sync_utils.nim index a81ad6e2f..e7fd82b57 100644 --- a/tests/waku_store_sync/sync_utils.nim +++ b/tests/waku_store_sync/sync_utils.nim @@ -1,4 +1,4 @@ -import std/[options, random], chronos +import std/[options, random], chronos, chronicles import waku/[ @@ -23,7 +23,7 @@ proc randomHash*(rng: var Rand): WakuMessageHash = proc newTestWakuRecon*( switch: Switch, idsRx: AsyncQueue[SyncID], - wantsTx: AsyncQueue[(PeerId, Fingerprint)], + wantsTx: AsyncQueue[PeerId], needsTx: AsyncQueue[(PeerId, Fingerprint)], cluster: uint16 = 1, shards: seq[uint16] = @[0, 1, 2, 3, 4, 5, 6, 7], @@ -51,7 +51,7 @@ proc newTestWakuRecon*( proc newTestWakuTransfer*( switch: Switch, idsTx: AsyncQueue[SyncID], - wantsRx: AsyncQueue[(PeerId, Fingerprint)], + wantsRx: AsyncQueue[PeerId], needsRx: AsyncQueue[(PeerId, Fingerprint)], ): SyncTransfer = let peerManager = PeerManager.new(switch) diff --git a/tests/waku_store_sync/test_protocol.nim b/tests/waku_store_sync/test_protocol.nim index df14de6a1..efdd6a885 100644 --- a/tests/waku_store_sync/test_protocol.nim +++ b/tests/waku_store_sync/test_protocol.nim @@ -27,7 +27,7 @@ suite "Waku Sync: reconciliation": var idsChannel {.threadvar.}: AsyncQueue[SyncID] - localWants {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)] + localWants {.threadvar.}: AsyncQueue[PeerId] remoteNeeds {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)] var server {.threadvar.}: SyncReconciliation @@ -43,7 +43,7 @@ suite "Waku Sync: reconciliation": await allFutures(serverSwitch.start(), clientSwitch.start()) idsChannel = newAsyncQueue[SyncID]() - localWants = newAsyncQueue[(PeerId, WakuMessageHash)]() + localWants = newAsyncQueue[PeerId]() remoteNeeds = newAsyncQueue[(PeerId, WakuMessageHash)]() server = await newTestWakuRecon(serverSwitch, idsChannel, localWants, remoteNeeds) @@ -61,7 +61,6 @@ suite "Waku Sync: reconciliation": asyncTest "sync 2 nodes both empty": check: idsChannel.len == 0 - localWants.len == 0 remoteNeeds.len == 0 let res = await client.storeSynchronization(some(serverPeerInfo)) @@ -69,7 +68,6 @@ suite "Waku Sync: reconciliation": check: idsChannel.len == 0 - localWants.len == 0 remoteNeeds.len == 0 asyncTest "sync 2 nodes empty client full server": @@ -141,8 +139,6 @@ suite "Waku Sync: reconciliation": check: remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false - localWants.contains((clientPeerInfo.peerId, hash3)) == false - localWants.contains((serverPeerInfo.peerId, hash2)) == false var syncRes = await client.storeSynchronization(some(serverPeerInfo)) assert syncRes.isOk(), $syncRes.error @@ -150,8 +146,6 @@ suite "Waku Sync: reconciliation": check: remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == true remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == true - localWants.contains((clientPeerInfo.peerId, hash3)) == true - localWants.contains((serverPeerInfo.peerId, hash2)) == true asyncTest "sync 2 nodes different shards": let @@ -170,8 +164,6 @@ suite "Waku Sync: reconciliation": check: remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false - localWants.contains((clientPeerInfo.peerId, hash3)) == false - localWants.contains((serverPeerInfo.peerId, hash2)) == false server = await newTestWakuRecon( serverSwitch, idsChannel, localWants, remoteNeeds, shards = @[0.uint16, 1, 2, 3] @@ -185,7 +177,6 @@ suite "Waku Sync: reconciliation": check: remoteNeeds.len == 0 - localWants.len == 0 asyncTest "sync 2 nodes same hashes": let @@ -200,14 +191,12 @@ suite "Waku Sync: reconciliation": client.messageIngress(hash2, msg2) check: - localWants.len == 0 remoteNeeds.len == 0 let res = await client.storeSynchronization(some(serverPeerInfo)) assert res.isOk(), $res.error check: - localWants.len == 0 remoteNeeds.len == 0 asyncTest "sync 2 nodes 100K msgs 1 diff": @@ -236,14 +225,12 @@ suite "Waku Sync: reconciliation": timestamp += Timestamp(part) check: - localWants.contains((serverPeerInfo.peerId, WakuMessageHash(diff))) == false remoteNeeds.contains((clientPeerInfo.peerId, WakuMessageHash(diff))) == false let res = await client.storeSynchronization(some(serverPeerInfo)) assert res.isOk(), $res.error check: - localWants.contains((serverPeerInfo.peerId, WakuMessageHash(diff))) == true remoteNeeds.contains((clientPeerInfo.peerId, WakuMessageHash(diff))) == true asyncTest "sync 2 nodes 10K msgs 1K diffs": @@ -286,7 +273,6 @@ suite "Waku Sync: reconciliation": continue check: - localWants.len == 0 remoteNeeds.len == 0 let res = await client.storeSynchronization(some(serverPeerInfo)) @@ -294,7 +280,6 @@ suite "Waku Sync: reconciliation": # timimg issue make it hard to match exact numbers check: - localWants.len > 900 remoteNeeds.len > 900 suite "Waku Sync: transfer": @@ -310,10 +295,10 @@ suite "Waku Sync: transfer": var serverIds {.threadvar.}: AsyncQueue[SyncID] - serverLocalWants {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)] + serverLocalWants {.threadvar.}: AsyncQueue[PeerId] serverRemoteNeeds {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)] clientIds {.threadvar.}: AsyncQueue[SyncID] - clientLocalWants {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)] + clientLocalWants {.threadvar.}: AsyncQueue[PeerId] clientRemoteNeeds {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)] var @@ -341,7 +326,7 @@ suite "Waku Sync: transfer": clientPeerManager = PeerManager.new(clientSwitch) serverIds = newAsyncQueue[SyncID]() - serverLocalWants = newAsyncQueue[(PeerId, WakuMessageHash)]() + serverLocalWants = newAsyncQueue[PeerId]() serverRemoteNeeds = newAsyncQueue[(PeerId, WakuMessageHash)]() server = SyncTransfer.new( @@ -353,7 +338,7 @@ suite "Waku Sync: transfer": ) clientIds = newAsyncQueue[SyncID]() - clientLocalWants = newAsyncQueue[(PeerId, WakuMessageHash)]() + clientLocalWants = newAsyncQueue[PeerId]() clientRemoteNeeds = newAsyncQueue[(PeerId, WakuMessageHash)]() client = SyncTransfer.new( @@ -389,8 +374,8 @@ suite "Waku Sync: transfer": serverDriver = serverDriver.put(DefaultPubsubTopic, msgs) - # add server info and msg hash to client want channel - let want = (serverPeerInfo.peerId, hash) + # add server info to client want channel + let want = serverPeerInfo.peerId await clientLocalWants.put(want) # add client info and msg hash to server need channel diff --git a/waku/node/waku_node.nim b/waku/node/waku_node.nim index a544bdc80..ce86c3c57 100644 --- a/waku/node/waku_node.nim +++ b/waku/node/waku_node.nim @@ -212,7 +212,7 @@ proc mountStoreSync*( storeSyncRelayJitter = 20, ): Future[Result[void, string]] {.async.} = let idsChannel = newAsyncQueue[SyncID](0) - let wantsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](0) + let wantsChannel = newAsyncQueue[PeerId](0) let needsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](0) var cluster: uint16 diff --git a/waku/waku_store_sync/reconciliation.nim b/waku/waku_store_sync/reconciliation.nim index c08a9e434..d9912a3df 100644 --- a/waku/waku_store_sync/reconciliation.nim +++ b/waku/waku_store_sync/reconciliation.nim @@ -46,13 +46,10 @@ type SyncReconciliation* = ref object of LPProtocol storage: SyncStorage - # Receive IDs from transfer protocol for storage + # AsyncQueues are used as communication channels between + # reconciliation and transfer protocols. idsRx: AsyncQueue[SyncID] - - # Send Hashes to transfer protocol for reception - localWantsTx: AsyncQueue[(PeerId, WakuMessageHash)] - - # Send Hashes to transfer protocol for transmission + localWantsTx: AsyncQueue[PeerId] remoteNeedsTx: AsyncQueue[(PeerId, WakuMessageHash)] # params @@ -100,6 +97,9 @@ proc processRequest( roundTrips = 0 diffs = 0 + # Signal to transfer protocol that this reconciliation is starting + await self.localWantsTx.addLast(conn.peerId) + while true: let readRes = catch: await conn.readLp(int.high) @@ -143,7 +143,6 @@ proc processRequest( diffs.inc() for hash in hashToRecv: - self.localWantsTx.addLastNoWait((conn.peerId, hash)) diffs.inc() rawPayload = sendPayload.deltaEncode() @@ -168,6 +167,9 @@ proc processRequest( continue + # Signal to transfer protocol that this reconciliation is done + await self.localWantsTx.addLast(conn.peerId) + reconciliation_roundtrips.observe(roundTrips) reconciliation_differences.observe(diffs) @@ -296,7 +298,7 @@ proc new*( syncInterval: timer.Duration = DefaultSyncInterval, relayJitter: timer.Duration = DefaultGossipSubJitter, idsRx: AsyncQueue[SyncID], - localWantsTx: AsyncQueue[(PeerId, WakuMessageHash)], + localWantsTx: AsyncQueue[PeerId], remoteNeedsTx: AsyncQueue[(PeerId, WakuMessageHash)], ): Future[Result[T, string]] {.async.} = let res = await initFillStorage(syncRange, wakuArchive) diff --git a/waku/waku_store_sync/transfer.nim b/waku/waku_store_sync/transfer.nim index 5a52cac9c..c1e5d3e37 100644 --- a/waku/waku_store_sync/transfer.nim +++ b/waku/waku_store_sync/transfer.nim @@ -37,9 +37,9 @@ type SyncTransfer* = ref object of LPProtocol idsTx: AsyncQueue[SyncID] # Receive Hashes from reconciliation protocol for reception - localWantsRx: AsyncQueue[(PeerId, WakuMessageHash)] + localWantsRx: AsyncQueue[PeerId] localWantsRxFut: Future[void] - inSessions: Table[PeerId, HashSet[WakuMessageHash]] + inSessions: HashSet[PeerId] # Receive Hashes from reconciliation protocol for transmission remoteNeedsRx: AsyncQueue[(PeerId, WakuMessageHash)] @@ -78,19 +78,14 @@ proc openConnection( return ok(conn) proc wantsReceiverLoop(self: SyncTransfer) {.async.} = - ## Waits for message hashes, - ## store the peers and hashes locally as - ## "supposed to be received" + ## Waits for peer ids of nodes + ## we are reconciliating with while true: # infinite loop - let (peerId, fingerprint) = await self.localWantsRx.popFirst() + let peerId = await self.localWantsRx.popFirst() - self.inSessions.withValue(peerId, value): - value[].incl(fingerprint) - do: - var hashes = initHashSet[WakuMessageHash]() - hashes.incl(fingerprint) - self.inSessions[peerId] = hashes + if self.inSessions.containsOrIncl(peerId): + self.inSessions.excl(peerId) return @@ -137,6 +132,10 @@ proc needsReceiverLoop(self: SyncTransfer) {.async.} = proc initProtocolHandler(self: SyncTransfer) = let handler = proc(conn: Connection, proto: string) {.async, closure.} = while true: + if not self.inSessions.contains(conn.peerId): + error "unwanted peer, disconnecting", remote = conn.peerId + break + let readRes = catch: await conn.readLp(int64(DefaultMaxWakuMessageSize)) @@ -157,16 +156,6 @@ proc initProtocolHandler(self: SyncTransfer) = let hash = computeMessageHash(pubsub, msg) - self.inSessions.withValue(conn.peerId, value): - if value[].missingOrExcl(hash): - error "unwanted hash received, disconnecting" - self.inSessions.del(conn.peerId) - break - do: - error "unwanted hash received, disconnecting" - self.inSessions.del(conn.peerId) - break - #TODO verify msg RLN proof... (await self.wakuArchive.syncMessageIngress(hash, pubsub, msg)).isOkOr: @@ -193,7 +182,7 @@ proc new*( peerManager: PeerManager, wakuArchive: WakuArchive, idsTx: AsyncQueue[SyncID], - localWantsRx: AsyncQueue[(PeerId, WakuMessageHash)], + localWantsRx: AsyncQueue[PeerId], remoteNeedsRx: AsyncQueue[(PeerId, WakuMessageHash)], ): T = var transfer = SyncTransfer(