From d389d9678930599d0496c39697d2816b2b9bedb9 Mon Sep 17 00:00:00 2001 From: Simon-Pierre Vivier Date: Wed, 25 Sep 2024 05:11:57 -0400 Subject: [PATCH] feat: rendezvous refactor (#1183) Hello! This PR aim to refactor rendezvous code so that it is easier to impl. Waku rdv strategy. The hardcoded min and max TTL were out of range with what we needed and specifying which peers to interact with is also needed since Waku deals with peers on multiple separate shards. I tried to keep the changes to a minimum, specifically I did not change the name of any public procs which result in less than descriptive names in some cases. I also wanted to return results instead of raising exceptions but didn't. Would it be acceptable to do so? Please advise on best practices, thank you. --------- Co-authored-by: Ludovic Chenut --- libp2p/protocols/rendezvous.nim | 110 +++++++++++++++++++++++--------- tests/testrendezvous.nim | 14 +++- 2 files changed, 90 insertions(+), 34 deletions(-) diff --git a/libp2p/protocols/rendezvous.nim b/libp2p/protocols/rendezvous.nim index ad8dbfebd..06155bad2 100644 --- a/libp2p/protocols/rendezvous.nim +++ b/libp2p/protocols/rendezvous.nim @@ -35,8 +35,6 @@ const RendezVousCodec* = "/rendezvous/1.0.0" MinimumDuration* = 2.hours MaximumDuration = 72.hours - MinimumTTL = MinimumDuration.seconds.uint64 - MaximumTTL = MaximumDuration.seconds.uint64 RegistrationLimitPerPeer = 1000 DiscoverLimit = 1000'u64 SemaphoreDefaultSize = 5 @@ -320,6 +318,10 @@ type peers: seq[PeerId] cookiesSaved: Table[PeerId, Table[string, seq[byte]]] switch: Switch + minDuration: Duration + maxDuration: Duration + minTTL: uint64 + maxTTL: uint64 proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] = if spr.len == 0: @@ -395,7 +397,7 @@ proc save( rdv.registered.add( RegisteredData( peerId: peerId, - expiration: Moment.now() + r.ttl.get(MinimumTTL).int64.seconds, + expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds, data: r, ) ) @@ -409,8 +411,8 @@ proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] = libp2p_rendezvous_register.inc() if r.ns.len notin 1 .. 255: return conn.sendRegisterResponseError(InvalidNamespace) - let ttl = r.ttl.get(MinimumTTL) - if ttl notin MinimumTTL .. MaximumTTL: + let ttl = r.ttl.get(rdv.minTTL) + if ttl notin rdv.minTTL .. rdv.maxTTL: return conn.sendRegisterResponseError(InvalidTTL) let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId) if pr.isErr(): @@ -506,24 +508,35 @@ proc advertisePeer(rdv: RendezVous, peer: PeerId, msg: seq[byte]) {.async.} = await rdv.sema.acquire() discard await advertiseWrap().withTimeout(5.seconds) -method advertise*( - rdv: RendezVous, ns: string, ttl: Duration = MinimumDuration -) {.async, base.} = - let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr: - raise newException(RendezVousError, "Wrong Signed Peer Record") +proc advertise*( + rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId] +) {.async.} = if ns.len notin 1 .. 255: raise newException(RendezVousError, "Invalid namespace") - if ttl notin MinimumDuration .. MaximumDuration: - raise newException(RendezVousError, "Invalid time to live") + + if ttl notin rdv.minDuration .. rdv.maxDuration: + raise newException(RendezVousError, "Invalid time to live: " & $ttl) + + let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr: + raise newException(RendezVousError, "Wrong Signed Peer Record") + let r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64)) msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r))) + rdv.save(ns, rdv.switch.peerInfo.peerId, r) - let fut = collect(newSeq()): - for peer in rdv.peers: + + let futs = collect(newSeq()): + for peer in peers: trace "Send Advertise", peerId = peer, ns rdv.advertisePeer(peer, msg.buffer) - await allFutures(fut) + + await allFutures(futs) + +method advertise*( + rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration +) {.async, base.} = + await rdv.advertise(ns, ttl, rdv.peers) proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] = let @@ -540,9 +553,8 @@ proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] = @[] proc request*( - rdv: RendezVous, ns: string, l: int = DiscoverLimit.int + rdv: RendezVous, ns: string, l: int = DiscoverLimit.int, peers: seq[PeerId] ): Future[seq[PeerRecord]] {.async.} = - let nsSalted = ns & rdv.salt var s: Table[PeerId, (PeerRecord, Register)] limit: uint64 @@ -587,8 +599,8 @@ proc request*( for r in resp.registrations: if limit == 0: return - let ttl = r.ttl.get(MaximumTTL + 1) - if ttl > MaximumTTL: + let ttl = r.ttl.get(rdv.maxTTL + 1) + if ttl > rdv.maxTTL: continue let spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr: @@ -596,7 +608,7 @@ proc request*( pr = spr.data if s.hasKey(pr.peerId): let (prSaved, rSaved) = s[pr.peerId] - if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(MaximumTTL) < ttl) or + if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or prSaved.seqNo < pr.seqNo: s[pr.peerId] = (pr, r) else: @@ -605,8 +617,6 @@ proc request*( for (_, r) in s.values(): rdv.save(ns, peer, r, false) - # copy to avoid resizes during the loop - let peers = rdv.peers for peer in peers: if limit == 0: break @@ -621,6 +631,11 @@ proc request*( trace "exception catch in request", description = exc.msg return toSeq(s.values()).mapIt(it[0]) +proc request*( + rdv: RendezVous, ns: string, l: int = DiscoverLimit.int +): Future[seq[PeerRecord]] {.async.} = + await rdv.request(ns, l, rdv.peers) + proc unsubscribeLocally*(rdv: RendezVous, ns: string) = let nsSalted = ns & rdv.salt try: @@ -630,16 +645,15 @@ proc unsubscribeLocally*(rdv: RendezVous, ns: string) = except KeyError: return -proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = - # TODO: find a way to improve this, maybe something similar to the advertise +proc unsubscribe*(rdv: RendezVous, ns: string, peerIds: seq[PeerId]) {.async.} = if ns.len notin 1 .. 255: raise newException(RendezVousError, "Invalid namespace") - rdv.unsubscribeLocally(ns) + let msg = encode( Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns))) ) - proc unsubscribePeer(rdv: RendezVous, peerId: PeerId) {.async.} = + proc unsubscribePeer(peerId: PeerId) {.async.} = try: let conn = await rdv.switch.dial(peerId, RendezVousCodec) defer: @@ -648,8 +662,16 @@ proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = except CatchableError as exc: trace "exception while unsubscribing", description = exc.msg - for peer in rdv.peers: - discard await rdv.unsubscribePeer(peer).withTimeout(5.seconds) + let futs = collect(newSeq()): + for peer in peerIds: + unsubscribePeer(peer) + + discard await allFutures(futs).withTimeout(5.seconds) + +proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = + rdv.unsubscribeLocally(ns) + + await rdv.unsubscribe(ns, rdv.peers) proc setup*(rdv: RendezVous, switch: Switch) = rdv.switch = switch @@ -662,7 +684,25 @@ proc setup*(rdv: RendezVous, switch: Switch) = rdv.switch.addPeerEventHandler(handlePeer, Joined) rdv.switch.addPeerEventHandler(handlePeer, Left) -proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = +proc new*( + T: typedesc[RendezVous], + rng: ref HmacDrbgContext = newRng(), + minDuration = MinimumDuration, + maxDuration = MaximumDuration, +): T {.raises: [RendezVousError].} = + if minDuration < 1.minutes: + raise newException(RendezVousError, "TTL too short: 1 minute minimum") + + if maxDuration > 72.hours: + raise newException(RendezVousError, "TTL too long: 72 hours maximum") + + if minDuration >= maxDuration: + raise newException(RendezVousError, "Minimum TTL longer than maximum") + + let + minTTL = minDuration.seconds.uint64 + maxTTL = maxDuration.seconds.uint64 + let rdv = T( rng: rng, salt: string.fromBytes(generateBytes(rng[], 8)), @@ -670,6 +710,10 @@ proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = defaultDT: Moment.now() - 1.days, #registerEvent: newAsyncEvent(), sema: newAsyncSemaphore(SemaphoreDefaultSize), + minDuration: minDuration, + maxDuration: maxDuration, + minTTL: minTTL, + maxTTL: maxTTL, ) logScope: topics = "libp2p discovery rendezvous" @@ -701,9 +745,13 @@ proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = return rdv proc new*( - T: typedesc[RendezVous], switch: Switch, rng: ref HmacDrbgContext = newRng() + T: typedesc[RendezVous], + switch: Switch, + rng: ref HmacDrbgContext = newRng(), + minDuration = MinimumDuration, + maxDuration = MaximumDuration, ): T = - let rdv = T.new(rng) + let rdv = T.new(rng, minDuration, maxDuration) rdv.setup(switch) return rdv diff --git a/tests/testrendezvous.nim b/tests/testrendezvous.nim index 12494a13c..e0c4c2577 100644 --- a/tests/testrendezvous.nim +++ b/tests/testrendezvous.nim @@ -126,7 +126,7 @@ suite "RendezVous": asyncTest "Various local error": let - rdv = RendezVous.new() + rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours) switch = createSwitch(rdv) expect RendezVousError: discard await rdv.request("A".repeat(300)) @@ -137,6 +137,14 @@ suite "RendezVous": expect RendezVousError: await rdv.advertise("A".repeat(300)) expect RendezVousError: - await rdv.advertise("A", 2.weeks) + await rdv.advertise("A", 73.hours) expect RendezVousError: - await rdv.advertise("A", 5.minutes) + await rdv.advertise("A", 30.seconds) + + test "Various config error": + expect RendezVousError: + discard RendezVous.new(minDuration = 30.seconds) + expect RendezVousError: + discard RendezVous.new(maxDuration = 73.hours) + expect RendezVousError: + discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)