diff --git a/libp2p/protocols/rendezvous.nim b/libp2p/protocols/rendezvous.nim index ad8dbfebd..06155bad2 100644 --- a/libp2p/protocols/rendezvous.nim +++ b/libp2p/protocols/rendezvous.nim @@ -35,8 +35,6 @@ const RendezVousCodec* = "/rendezvous/1.0.0" MinimumDuration* = 2.hours MaximumDuration = 72.hours - MinimumTTL = MinimumDuration.seconds.uint64 - MaximumTTL = MaximumDuration.seconds.uint64 RegistrationLimitPerPeer = 1000 DiscoverLimit = 1000'u64 SemaphoreDefaultSize = 5 @@ -320,6 +318,10 @@ type peers: seq[PeerId] cookiesSaved: Table[PeerId, Table[string, seq[byte]]] switch: Switch + minDuration: Duration + maxDuration: Duration + minTTL: uint64 + maxTTL: uint64 proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] = if spr.len == 0: @@ -395,7 +397,7 @@ proc save( rdv.registered.add( RegisteredData( peerId: peerId, - expiration: Moment.now() + r.ttl.get(MinimumTTL).int64.seconds, + expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds, data: r, ) ) @@ -409,8 +411,8 @@ proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] = libp2p_rendezvous_register.inc() if r.ns.len notin 1 .. 255: return conn.sendRegisterResponseError(InvalidNamespace) - let ttl = r.ttl.get(MinimumTTL) - if ttl notin MinimumTTL .. MaximumTTL: + let ttl = r.ttl.get(rdv.minTTL) + if ttl notin rdv.minTTL .. rdv.maxTTL: return conn.sendRegisterResponseError(InvalidTTL) let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId) if pr.isErr(): @@ -506,24 +508,35 @@ proc advertisePeer(rdv: RendezVous, peer: PeerId, msg: seq[byte]) {.async.} = await rdv.sema.acquire() discard await advertiseWrap().withTimeout(5.seconds) -method advertise*( - rdv: RendezVous, ns: string, ttl: Duration = MinimumDuration -) {.async, base.} = - let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr: - raise newException(RendezVousError, "Wrong Signed Peer Record") +proc advertise*( + rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId] +) {.async.} = if ns.len notin 1 .. 255: raise newException(RendezVousError, "Invalid namespace") - if ttl notin MinimumDuration .. MaximumDuration: - raise newException(RendezVousError, "Invalid time to live") + + if ttl notin rdv.minDuration .. rdv.maxDuration: + raise newException(RendezVousError, "Invalid time to live: " & $ttl) + + let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr: + raise newException(RendezVousError, "Wrong Signed Peer Record") + let r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64)) msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r))) + rdv.save(ns, rdv.switch.peerInfo.peerId, r) - let fut = collect(newSeq()): - for peer in rdv.peers: + + let futs = collect(newSeq()): + for peer in peers: trace "Send Advertise", peerId = peer, ns rdv.advertisePeer(peer, msg.buffer) - await allFutures(fut) + + await allFutures(futs) + +method advertise*( + rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration +) {.async, base.} = + await rdv.advertise(ns, ttl, rdv.peers) proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] = let @@ -540,9 +553,8 @@ proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] = @[] proc request*( - rdv: RendezVous, ns: string, l: int = DiscoverLimit.int + rdv: RendezVous, ns: string, l: int = DiscoverLimit.int, peers: seq[PeerId] ): Future[seq[PeerRecord]] {.async.} = - let nsSalted = ns & rdv.salt var s: Table[PeerId, (PeerRecord, Register)] limit: uint64 @@ -587,8 +599,8 @@ proc request*( for r in resp.registrations: if limit == 0: return - let ttl = r.ttl.get(MaximumTTL + 1) - if ttl > MaximumTTL: + let ttl = r.ttl.get(rdv.maxTTL + 1) + if ttl > rdv.maxTTL: continue let spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr: @@ -596,7 +608,7 @@ proc request*( pr = spr.data if s.hasKey(pr.peerId): let (prSaved, rSaved) = s[pr.peerId] - if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(MaximumTTL) < ttl) or + if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or prSaved.seqNo < pr.seqNo: s[pr.peerId] = (pr, r) else: @@ -605,8 +617,6 @@ proc request*( for (_, r) in s.values(): rdv.save(ns, peer, r, false) - # copy to avoid resizes during the loop - let peers = rdv.peers for peer in peers: if limit == 0: break @@ -621,6 +631,11 @@ proc request*( trace "exception catch in request", description = exc.msg return toSeq(s.values()).mapIt(it[0]) +proc request*( + rdv: RendezVous, ns: string, l: int = DiscoverLimit.int +): Future[seq[PeerRecord]] {.async.} = + await rdv.request(ns, l, rdv.peers) + proc unsubscribeLocally*(rdv: RendezVous, ns: string) = let nsSalted = ns & rdv.salt try: @@ -630,16 +645,15 @@ proc unsubscribeLocally*(rdv: RendezVous, ns: string) = except KeyError: return -proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = - # TODO: find a way to improve this, maybe something similar to the advertise +proc unsubscribe*(rdv: RendezVous, ns: string, peerIds: seq[PeerId]) {.async.} = if ns.len notin 1 .. 255: raise newException(RendezVousError, "Invalid namespace") - rdv.unsubscribeLocally(ns) + let msg = encode( Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns))) ) - proc unsubscribePeer(rdv: RendezVous, peerId: PeerId) {.async.} = + proc unsubscribePeer(peerId: PeerId) {.async.} = try: let conn = await rdv.switch.dial(peerId, RendezVousCodec) defer: @@ -648,8 +662,16 @@ proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = except CatchableError as exc: trace "exception while unsubscribing", description = exc.msg - for peer in rdv.peers: - discard await rdv.unsubscribePeer(peer).withTimeout(5.seconds) + let futs = collect(newSeq()): + for peer in peerIds: + unsubscribePeer(peer) + + discard await allFutures(futs).withTimeout(5.seconds) + +proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = + rdv.unsubscribeLocally(ns) + + await rdv.unsubscribe(ns, rdv.peers) proc setup*(rdv: RendezVous, switch: Switch) = rdv.switch = switch @@ -662,7 +684,25 @@ proc setup*(rdv: RendezVous, switch: Switch) = rdv.switch.addPeerEventHandler(handlePeer, Joined) rdv.switch.addPeerEventHandler(handlePeer, Left) -proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = +proc new*( + T: typedesc[RendezVous], + rng: ref HmacDrbgContext = newRng(), + minDuration = MinimumDuration, + maxDuration = MaximumDuration, +): T {.raises: [RendezVousError].} = + if minDuration < 1.minutes: + raise newException(RendezVousError, "TTL too short: 1 minute minimum") + + if maxDuration > 72.hours: + raise newException(RendezVousError, "TTL too long: 72 hours maximum") + + if minDuration >= maxDuration: + raise newException(RendezVousError, "Minimum TTL longer than maximum") + + let + minTTL = minDuration.seconds.uint64 + maxTTL = maxDuration.seconds.uint64 + let rdv = T( rng: rng, salt: string.fromBytes(generateBytes(rng[], 8)), @@ -670,6 +710,10 @@ proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = defaultDT: Moment.now() - 1.days, #registerEvent: newAsyncEvent(), sema: newAsyncSemaphore(SemaphoreDefaultSize), + minDuration: minDuration, + maxDuration: maxDuration, + minTTL: minTTL, + maxTTL: maxTTL, ) logScope: topics = "libp2p discovery rendezvous" @@ -701,9 +745,13 @@ proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = return rdv proc new*( - T: typedesc[RendezVous], switch: Switch, rng: ref HmacDrbgContext = newRng() + T: typedesc[RendezVous], + switch: Switch, + rng: ref HmacDrbgContext = newRng(), + minDuration = MinimumDuration, + maxDuration = MaximumDuration, ): T = - let rdv = T.new(rng) + let rdv = T.new(rng, minDuration, maxDuration) rdv.setup(switch) return rdv diff --git a/tests/testrendezvous.nim b/tests/testrendezvous.nim index 12494a13c..e0c4c2577 100644 --- a/tests/testrendezvous.nim +++ b/tests/testrendezvous.nim @@ -126,7 +126,7 @@ suite "RendezVous": asyncTest "Various local error": let - rdv = RendezVous.new() + rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours) switch = createSwitch(rdv) expect RendezVousError: discard await rdv.request("A".repeat(300)) @@ -137,6 +137,14 @@ suite "RendezVous": expect RendezVousError: await rdv.advertise("A".repeat(300)) expect RendezVousError: - await rdv.advertise("A", 2.weeks) + await rdv.advertise("A", 73.hours) expect RendezVousError: - await rdv.advertise("A", 5.minutes) + await rdv.advertise("A", 30.seconds) + + test "Various config error": + expect RendezVousError: + discard RendezVous.new(minDuration = 30.seconds) + expect RendezVousError: + discard RendezVous.new(maxDuration = 73.hours) + expect RendezVousError: + discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)