From 75136303114a195329f71b7feaa294054060af31 Mon Sep 17 00:00:00 2001 From: Ludovic Chenut Date: Thu, 15 Dec 2022 13:39:53 +0100 Subject: [PATCH] Another round of fixes --- libp2p/protocols/connectivity/autorelay.nim | 66 ++++++++++++--------- tests/testautorelay.nim | 10 ++-- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/libp2p/protocols/connectivity/autorelay.nim b/libp2p/protocols/connectivity/autorelay.nim index 2058befab..7abcd02ff 100644 --- a/libp2p/protocols/connectivity/autorelay.nim +++ b/libp2p/protocols/connectivity/autorelay.nim @@ -20,7 +20,8 @@ logScope: topics = "libp2p autorelay" type - OnReservationHandler = proc (ma: MultiAddress): Future[void] {.gcsafe, raises: [Defect].} + OnReservationHandler = proc (addresses: seq[MultiAddress]): + Future[void] {.gcsafe, raises: [Defect].} AutoRelayService* = ref object of Service running: bool @@ -29,8 +30,10 @@ type numRelays: int relayPeers: Table[PeerId, Future[void]] relayAddresses: Table[PeerId, MultiAddress] - peerJoined: AsyncEvent + backedOff: seq[PeerId] + peerAvailable: AsyncEvent onReservation: OnReservationHandler + rng: ref HmacDrbgContext proc reserveAndUpdate(self: AutoRelayService, relayPid: PeerId, selfPid: PeerId) {.async.} = while self.running: @@ -39,54 +42,59 @@ proc reserveAndUpdate(self: AutoRelayService, relayPid: PeerId, selfPid: PeerId) relayedAddr = MultiAddress.init($(rsvp.addrs[0]) & "/p2p-circuit/p2p/" & $selfPid).tryGet() - if not self.onReservation.isNil(): - await self.onReservation(relayedAddr) - self.relayAddresses[relayPid] = relayedAddr - await sleepAsync chronos.seconds(rsvp.expire.int64 - times.now().utc.toTime.toUnix) + if relayPid notin self.relayAddresses or self.relayAddresses[relayPid] != relayedAddr: + self.relayAddresses[relayPid] = relayedAddr + if not self.onReservation.isNil(): + await self.onReservation(toSeq(self.relayAddresses.values)) + await sleepAsync chronos.seconds(rsvp.expire.int64 - times.now().utc.toTime.toUnix - 5) method setup*(self: AutoRelayService, switch: Switch): Future[bool] {.async, gcsafe.} = let hasBeenSetUp = await procCall Service(self).setup(switch) if hasBeenSetUp: proc handlePeerJoined(peerId: PeerId, event: PeerEvent) {.async.} = if self.relayPeers.len < self.numRelays: - self.peerJoined.fire() + self.peerAvailable.fire() proc handlePeerLeft(peerId: PeerId, event: PeerEvent) {.async.} = - if peerId in self.relayPeers: - self.relayPeers[peerId].cancel() + self.relayPeers.withValue(peerId, future): + future[].cancel() switch.addPeerEventHandler(handlePeerJoined, Joined) switch.addPeerEventHandler(handlePeerLeft, Left) return hasBeenSetUp -method innerRun(self: AutoRelayService, switch: Switch) {.async, gcsafe.} = +proc manageBackedOff(self: AutoRelayService, pid: PeerId) {.async.} = + await sleepAsync(chronos.seconds(5)) + self.backedOff.keepItIf(it != pid) + self.peerAvailable.fire() + +proc innerRun(self: AutoRelayService, switch: Switch) {.async, gcsafe.} = while true: # Remove relayPeers that failed - var peersToRemove: seq[PeerId] - for k, v in self.relayPeers: - if v.finished(): - peersToRemove.add(k) - for k in peersToRemove: - self.relayPeers.del(k) - self.relayAddresses.del(k) - if peersToRemove.len() > 0: - await sleepAsync(500.millis) # To avoid ddosing our relayPeers in certain condition + let peers = toSeq(self.relayPeers.keys()) + for k in peers: + if self.relayPeers[k].finished(): + self.relayPeers.del(k) + self.relayAddresses.del(k) + if not self.onReservation.isNil(): + await self.onReservation(toSeq(self.relayAddresses.values)) + # To avoid ddosing our peers in certain conditions + self.backedOff.add(k) + asyncSpawn self.manageBackedOff(k) # Get all connected relayPeers - let rng = newRng() var connectedPeers = switch.connectedPeers(Direction.Out) connectedPeers.keepItIf(RelayV2HopCodec in switch.peerStore[ProtoBook][it] or it notin self.relayPeers) - rng.shuffle(connectedPeers) + self.rng.shuffle(connectedPeers) for relayPid in connectedPeers: if self.relayPeers.len() >= self.numRelays: break - if RelayV2HopCodec in switch.peerStore[ProtoBook][relayPid]: - self.relayPeers[relayPid] = self.reserveAndUpdate(relayPid, switch.peerInfo.peerId) + self.relayPeers[relayPid] = self.reserveAndUpdate(relayPid, switch.peerInfo.peerId) let peersFutures = toSeq(self.relayPeers.values()) if self.relayPeers.len() < self.numRelays: - self.peerJoined.clear() - await one(peersFutures) or self.peerJoined.wait() + self.peerAvailable.clear() + await one(peersFutures) or self.peerAvailable.wait() else: discard await one(peersFutures) @@ -104,14 +112,16 @@ method stop*(self: AutoRelayService, switch: Switch): Future[bool] {.async, gcsa self.runner.cancel() return hasBeenStopped -method getAddresses*(self: AutoRelayService): seq[MultiAddress] = +proc getAddresses*(self: AutoRelayService): seq[MultiAddress] = result = toSeq(self.relayAddresses.values) proc new*(T: typedesc[AutoRelayService], numRelays: int, client: RelayClient, - onReservation: OnReservationHandler): T = + onReservation: OnReservationHandler, + rng: ref HmacDrbgContext): T = T(numRelays: numRelays, client: client, onReservation: onReservation, - peerJoined: newAsyncEvent()) + peerAvailable: newAsyncEvent(), + rng: rng) diff --git a/tests/testautorelay.nim b/tests/testautorelay.nim index 3f09165c3..c26565d27 100644 --- a/tests/testautorelay.nim +++ b/tests/testautorelay.nim @@ -30,12 +30,12 @@ suite "Autorelay": let client = RelayClient.new() let switch = createSwitch(client) let fut = newFuture[void]() - proc checkMA(address: MultiAddress) {.async.} = - check: address == MultiAddress.init($relay.peerInfo.addrs[0] & "/p2p/" & - $relay.peerInfo.peerId & "/p2p-circuit/p2p/" & - $switch.peerInfo.peerId).get() + proc checkMA(address: seq[MultiAddress]) {.async.} = + check: address[0] == MultiAddress.init($relay.peerInfo.addrs[0] & "/p2p/" & + $relay.peerInfo.peerId & "/p2p-circuit/p2p/" & + $switch.peerInfo.peerId).get() fut.complete() - let autorelay = AutoRelayService.new(3, client, checkMA) + let autorelay = AutoRelayService.new(3, client, checkMA, newRng()) switch.addService(autorelay) await switch.start() await relay.start()