feat: refactor waku sync DOS protection (#3391)

This commit is contained in:
Simon-Pierre Vivier 2025-04-24 09:07:21 -04:00 committed by GitHub
parent 8394c15a1a
commit 0c63ce4e9b
5 changed files with 34 additions and 58 deletions

View File

@ -1,4 +1,4 @@
import std/[options, random], chronos
import std/[options, random], chronos, chronicles
import
waku/[
@ -23,7 +23,7 @@ proc randomHash*(rng: var Rand): WakuMessageHash =
proc newTestWakuRecon*(
switch: Switch,
idsRx: AsyncQueue[SyncID],
wantsTx: AsyncQueue[(PeerId, Fingerprint)],
wantsTx: AsyncQueue[PeerId],
needsTx: AsyncQueue[(PeerId, Fingerprint)],
cluster: uint16 = 1,
shards: seq[uint16] = @[0, 1, 2, 3, 4, 5, 6, 7],
@ -51,7 +51,7 @@ proc newTestWakuRecon*(
proc newTestWakuTransfer*(
switch: Switch,
idsTx: AsyncQueue[SyncID],
wantsRx: AsyncQueue[(PeerId, Fingerprint)],
wantsRx: AsyncQueue[PeerId],
needsRx: AsyncQueue[(PeerId, Fingerprint)],
): SyncTransfer =
let peerManager = PeerManager.new(switch)

View File

@ -27,7 +27,7 @@ suite "Waku Sync: reconciliation":
var
idsChannel {.threadvar.}: AsyncQueue[SyncID]
localWants {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)]
localWants {.threadvar.}: AsyncQueue[PeerId]
remoteNeeds {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)]
var server {.threadvar.}: SyncReconciliation
@ -43,7 +43,7 @@ suite "Waku Sync: reconciliation":
await allFutures(serverSwitch.start(), clientSwitch.start())
idsChannel = newAsyncQueue[SyncID]()
localWants = newAsyncQueue[(PeerId, WakuMessageHash)]()
localWants = newAsyncQueue[PeerId]()
remoteNeeds = newAsyncQueue[(PeerId, WakuMessageHash)]()
server = await newTestWakuRecon(serverSwitch, idsChannel, localWants, remoteNeeds)
@ -61,7 +61,6 @@ suite "Waku Sync: reconciliation":
asyncTest "sync 2 nodes both empty":
check:
idsChannel.len == 0
localWants.len == 0
remoteNeeds.len == 0
let res = await client.storeSynchronization(some(serverPeerInfo))
@ -69,7 +68,6 @@ suite "Waku Sync: reconciliation":
check:
idsChannel.len == 0
localWants.len == 0
remoteNeeds.len == 0
asyncTest "sync 2 nodes empty client full server":
@ -141,8 +139,6 @@ suite "Waku Sync: reconciliation":
check:
remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false
remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false
localWants.contains((clientPeerInfo.peerId, hash3)) == false
localWants.contains((serverPeerInfo.peerId, hash2)) == false
var syncRes = await client.storeSynchronization(some(serverPeerInfo))
assert syncRes.isOk(), $syncRes.error
@ -150,8 +146,6 @@ suite "Waku Sync: reconciliation":
check:
remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == true
remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == true
localWants.contains((clientPeerInfo.peerId, hash3)) == true
localWants.contains((serverPeerInfo.peerId, hash2)) == true
asyncTest "sync 2 nodes different shards":
let
@ -170,8 +164,6 @@ suite "Waku Sync: reconciliation":
check:
remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false
remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false
localWants.contains((clientPeerInfo.peerId, hash3)) == false
localWants.contains((serverPeerInfo.peerId, hash2)) == false
server = await newTestWakuRecon(
serverSwitch, idsChannel, localWants, remoteNeeds, shards = @[0.uint16, 1, 2, 3]
@ -185,7 +177,6 @@ suite "Waku Sync: reconciliation":
check:
remoteNeeds.len == 0
localWants.len == 0
asyncTest "sync 2 nodes same hashes":
let
@ -200,14 +191,12 @@ suite "Waku Sync: reconciliation":
client.messageIngress(hash2, msg2)
check:
localWants.len == 0
remoteNeeds.len == 0
let res = await client.storeSynchronization(some(serverPeerInfo))
assert res.isOk(), $res.error
check:
localWants.len == 0
remoteNeeds.len == 0
asyncTest "sync 2 nodes 100K msgs 1 diff":
@ -236,14 +225,12 @@ suite "Waku Sync: reconciliation":
timestamp += Timestamp(part)
check:
localWants.contains((serverPeerInfo.peerId, WakuMessageHash(diff))) == false
remoteNeeds.contains((clientPeerInfo.peerId, WakuMessageHash(diff))) == false
let res = await client.storeSynchronization(some(serverPeerInfo))
assert res.isOk(), $res.error
check:
localWants.contains((serverPeerInfo.peerId, WakuMessageHash(diff))) == true
remoteNeeds.contains((clientPeerInfo.peerId, WakuMessageHash(diff))) == true
asyncTest "sync 2 nodes 10K msgs 1K diffs":
@ -286,7 +273,6 @@ suite "Waku Sync: reconciliation":
continue
check:
localWants.len == 0
remoteNeeds.len == 0
let res = await client.storeSynchronization(some(serverPeerInfo))
@ -294,7 +280,6 @@ suite "Waku Sync: reconciliation":
# timimg issue make it hard to match exact numbers
check:
localWants.len > 900
remoteNeeds.len > 900
suite "Waku Sync: transfer":
@ -310,10 +295,10 @@ suite "Waku Sync: transfer":
var
serverIds {.threadvar.}: AsyncQueue[SyncID]
serverLocalWants {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)]
serverLocalWants {.threadvar.}: AsyncQueue[PeerId]
serverRemoteNeeds {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)]
clientIds {.threadvar.}: AsyncQueue[SyncID]
clientLocalWants {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)]
clientLocalWants {.threadvar.}: AsyncQueue[PeerId]
clientRemoteNeeds {.threadvar.}: AsyncQueue[(PeerId, WakuMessageHash)]
var
@ -341,7 +326,7 @@ suite "Waku Sync: transfer":
clientPeerManager = PeerManager.new(clientSwitch)
serverIds = newAsyncQueue[SyncID]()
serverLocalWants = newAsyncQueue[(PeerId, WakuMessageHash)]()
serverLocalWants = newAsyncQueue[PeerId]()
serverRemoteNeeds = newAsyncQueue[(PeerId, WakuMessageHash)]()
server = SyncTransfer.new(
@ -353,7 +338,7 @@ suite "Waku Sync: transfer":
)
clientIds = newAsyncQueue[SyncID]()
clientLocalWants = newAsyncQueue[(PeerId, WakuMessageHash)]()
clientLocalWants = newAsyncQueue[PeerId]()
clientRemoteNeeds = newAsyncQueue[(PeerId, WakuMessageHash)]()
client = SyncTransfer.new(
@ -389,8 +374,8 @@ suite "Waku Sync: transfer":
serverDriver = serverDriver.put(DefaultPubsubTopic, msgs)
# add server info and msg hash to client want channel
let want = (serverPeerInfo.peerId, hash)
# add server info to client want channel
let want = serverPeerInfo.peerId
await clientLocalWants.put(want)
# add client info and msg hash to server need channel

View File

@ -212,7 +212,7 @@ proc mountStoreSync*(
storeSyncRelayJitter = 20,
): Future[Result[void, string]] {.async.} =
let idsChannel = newAsyncQueue[SyncID](0)
let wantsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](0)
let wantsChannel = newAsyncQueue[PeerId](0)
let needsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](0)
var cluster: uint16

View File

@ -46,13 +46,10 @@ type SyncReconciliation* = ref object of LPProtocol
storage: SyncStorage
# Receive IDs from transfer protocol for storage
# AsyncQueues are used as communication channels between
# reconciliation and transfer protocols.
idsRx: AsyncQueue[SyncID]
# Send Hashes to transfer protocol for reception
localWantsTx: AsyncQueue[(PeerId, WakuMessageHash)]
# Send Hashes to transfer protocol for transmission
localWantsTx: AsyncQueue[PeerId]
remoteNeedsTx: AsyncQueue[(PeerId, WakuMessageHash)]
# params
@ -100,6 +97,9 @@ proc processRequest(
roundTrips = 0
diffs = 0
# Signal to transfer protocol that this reconciliation is starting
await self.localWantsTx.addLast(conn.peerId)
while true:
let readRes = catch:
await conn.readLp(int.high)
@ -143,7 +143,6 @@ proc processRequest(
diffs.inc()
for hash in hashToRecv:
self.localWantsTx.addLastNoWait((conn.peerId, hash))
diffs.inc()
rawPayload = sendPayload.deltaEncode()
@ -168,6 +167,9 @@ proc processRequest(
continue
# Signal to transfer protocol that this reconciliation is done
await self.localWantsTx.addLast(conn.peerId)
reconciliation_roundtrips.observe(roundTrips)
reconciliation_differences.observe(diffs)
@ -296,7 +298,7 @@ proc new*(
syncInterval: timer.Duration = DefaultSyncInterval,
relayJitter: timer.Duration = DefaultGossipSubJitter,
idsRx: AsyncQueue[SyncID],
localWantsTx: AsyncQueue[(PeerId, WakuMessageHash)],
localWantsTx: AsyncQueue[PeerId],
remoteNeedsTx: AsyncQueue[(PeerId, WakuMessageHash)],
): Future[Result[T, string]] {.async.} =
let res = await initFillStorage(syncRange, wakuArchive)

View File

@ -37,9 +37,9 @@ type SyncTransfer* = ref object of LPProtocol
idsTx: AsyncQueue[SyncID]
# Receive Hashes from reconciliation protocol for reception
localWantsRx: AsyncQueue[(PeerId, WakuMessageHash)]
localWantsRx: AsyncQueue[PeerId]
localWantsRxFut: Future[void]
inSessions: Table[PeerId, HashSet[WakuMessageHash]]
inSessions: HashSet[PeerId]
# Receive Hashes from reconciliation protocol for transmission
remoteNeedsRx: AsyncQueue[(PeerId, WakuMessageHash)]
@ -78,19 +78,14 @@ proc openConnection(
return ok(conn)
proc wantsReceiverLoop(self: SyncTransfer) {.async.} =
## Waits for message hashes,
## store the peers and hashes locally as
## "supposed to be received"
## Waits for peer ids of nodes
## we are reconciliating with
while true: # infinite loop
let (peerId, fingerprint) = await self.localWantsRx.popFirst()
let peerId = await self.localWantsRx.popFirst()
self.inSessions.withValue(peerId, value):
value[].incl(fingerprint)
do:
var hashes = initHashSet[WakuMessageHash]()
hashes.incl(fingerprint)
self.inSessions[peerId] = hashes
if self.inSessions.containsOrIncl(peerId):
self.inSessions.excl(peerId)
return
@ -137,6 +132,10 @@ proc needsReceiverLoop(self: SyncTransfer) {.async.} =
proc initProtocolHandler(self: SyncTransfer) =
let handler = proc(conn: Connection, proto: string) {.async, closure.} =
while true:
if not self.inSessions.contains(conn.peerId):
error "unwanted peer, disconnecting", remote = conn.peerId
break
let readRes = catch:
await conn.readLp(int64(DefaultMaxWakuMessageSize))
@ -157,16 +156,6 @@ proc initProtocolHandler(self: SyncTransfer) =
let hash = computeMessageHash(pubsub, msg)
self.inSessions.withValue(conn.peerId, value):
if value[].missingOrExcl(hash):
error "unwanted hash received, disconnecting"
self.inSessions.del(conn.peerId)
break
do:
error "unwanted hash received, disconnecting"
self.inSessions.del(conn.peerId)
break
#TODO verify msg RLN proof...
(await self.wakuArchive.syncMessageIngress(hash, pubsub, msg)).isOkOr:
@ -193,7 +182,7 @@ proc new*(
peerManager: PeerManager,
wakuArchive: WakuArchive,
idsTx: AsyncQueue[SyncID],
localWantsRx: AsyncQueue[(PeerId, WakuMessageHash)],
localWantsRx: AsyncQueue[PeerId],
remoteNeedsRx: AsyncQueue[(PeerId, WakuMessageHash)],
): T =
var transfer = SyncTransfer(