feat: rendezvous refactor (#1183)

Hello!

This PR aim to refactor rendezvous code so that it is easier to impl.
Waku rdv strategy. The hardcoded min and max TTL were out of range with
what we needed and specifying which peers to interact with is also
needed since Waku deals with peers on multiple separate shards.

I tried to keep the changes to a minimum, specifically I did not change
the name of any public procs which result in less than descriptive names
in some cases. I also wanted to return results instead of raising
exceptions but didn't. Would it be acceptable to do so?

Please advise on best practices, thank you.

---------

Co-authored-by: Ludovic Chenut <ludovic@status.im>
This commit is contained in:
Simon-Pierre Vivier 2024-09-25 05:11:57 -04:00 committed by GitHub
parent 09fe199b6b
commit d389d96789
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 90 additions and 34 deletions

View File

@ -35,8 +35,6 @@ const
RendezVousCodec* = "/rendezvous/1.0.0" RendezVousCodec* = "/rendezvous/1.0.0"
MinimumDuration* = 2.hours MinimumDuration* = 2.hours
MaximumDuration = 72.hours MaximumDuration = 72.hours
MinimumTTL = MinimumDuration.seconds.uint64
MaximumTTL = MaximumDuration.seconds.uint64
RegistrationLimitPerPeer = 1000 RegistrationLimitPerPeer = 1000
DiscoverLimit = 1000'u64 DiscoverLimit = 1000'u64
SemaphoreDefaultSize = 5 SemaphoreDefaultSize = 5
@ -320,6 +318,10 @@ type
peers: seq[PeerId] peers: seq[PeerId]
cookiesSaved: Table[PeerId, Table[string, seq[byte]]] cookiesSaved: Table[PeerId, Table[string, seq[byte]]]
switch: Switch switch: Switch
minDuration: Duration
maxDuration: Duration
minTTL: uint64
maxTTL: uint64
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] = proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
if spr.len == 0: if spr.len == 0:
@ -395,7 +397,7 @@ proc save(
rdv.registered.add( rdv.registered.add(
RegisteredData( RegisteredData(
peerId: peerId, peerId: peerId,
expiration: Moment.now() + r.ttl.get(MinimumTTL).int64.seconds, expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
data: r, data: r,
) )
) )
@ -409,8 +411,8 @@ proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
libp2p_rendezvous_register.inc() libp2p_rendezvous_register.inc()
if r.ns.len notin 1 .. 255: if r.ns.len notin 1 .. 255:
return conn.sendRegisterResponseError(InvalidNamespace) return conn.sendRegisterResponseError(InvalidNamespace)
let ttl = r.ttl.get(MinimumTTL) let ttl = r.ttl.get(rdv.minTTL)
if ttl notin MinimumTTL .. MaximumTTL: if ttl notin rdv.minTTL .. rdv.maxTTL:
return conn.sendRegisterResponseError(InvalidTTL) return conn.sendRegisterResponseError(InvalidTTL)
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId) let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
if pr.isErr(): if pr.isErr():
@ -506,24 +508,35 @@ proc advertisePeer(rdv: RendezVous, peer: PeerId, msg: seq[byte]) {.async.} =
await rdv.sema.acquire() await rdv.sema.acquire()
discard await advertiseWrap().withTimeout(5.seconds) discard await advertiseWrap().withTimeout(5.seconds)
method advertise*( proc advertise*(
rdv: RendezVous, ns: string, ttl: Duration = MinimumDuration rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
) {.async, base.} = ) {.async.} =
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
raise newException(RendezVousError, "Wrong Signed Peer Record")
if ns.len notin 1 .. 255: if ns.len notin 1 .. 255:
raise newException(RendezVousError, "Invalid namespace") raise newException(RendezVousError, "Invalid namespace")
if ttl notin MinimumDuration .. MaximumDuration:
raise newException(RendezVousError, "Invalid time to live") if ttl notin rdv.minDuration .. rdv.maxDuration:
raise newException(RendezVousError, "Invalid time to live: " & $ttl)
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
raise newException(RendezVousError, "Wrong Signed Peer Record")
let let
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64)) r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r))) msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
rdv.save(ns, rdv.switch.peerInfo.peerId, r) rdv.save(ns, rdv.switch.peerInfo.peerId, r)
let fut = collect(newSeq()):
for peer in rdv.peers: let futs = collect(newSeq()):
for peer in peers:
trace "Send Advertise", peerId = peer, ns trace "Send Advertise", peerId = peer, ns
rdv.advertisePeer(peer, msg.buffer) rdv.advertisePeer(peer, msg.buffer)
await allFutures(fut)
await allFutures(futs)
method advertise*(
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
) {.async, base.} =
await rdv.advertise(ns, ttl, rdv.peers)
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] = proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
let let
@ -540,9 +553,8 @@ proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
@[] @[]
proc request*( proc request*(
rdv: RendezVous, ns: string, l: int = DiscoverLimit.int rdv: RendezVous, ns: string, l: int = DiscoverLimit.int, peers: seq[PeerId]
): Future[seq[PeerRecord]] {.async.} = ): Future[seq[PeerRecord]] {.async.} =
let nsSalted = ns & rdv.salt
var var
s: Table[PeerId, (PeerRecord, Register)] s: Table[PeerId, (PeerRecord, Register)]
limit: uint64 limit: uint64
@ -587,8 +599,8 @@ proc request*(
for r in resp.registrations: for r in resp.registrations:
if limit == 0: if limit == 0:
return return
let ttl = r.ttl.get(MaximumTTL + 1) let ttl = r.ttl.get(rdv.maxTTL + 1)
if ttl > MaximumTTL: if ttl > rdv.maxTTL:
continue continue
let let
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr: spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
@ -596,7 +608,7 @@ proc request*(
pr = spr.data pr = spr.data
if s.hasKey(pr.peerId): if s.hasKey(pr.peerId):
let (prSaved, rSaved) = s[pr.peerId] let (prSaved, rSaved) = s[pr.peerId]
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(MaximumTTL) < ttl) or if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
prSaved.seqNo < pr.seqNo: prSaved.seqNo < pr.seqNo:
s[pr.peerId] = (pr, r) s[pr.peerId] = (pr, r)
else: else:
@ -605,8 +617,6 @@ proc request*(
for (_, r) in s.values(): for (_, r) in s.values():
rdv.save(ns, peer, r, false) rdv.save(ns, peer, r, false)
# copy to avoid resizes during the loop
let peers = rdv.peers
for peer in peers: for peer in peers:
if limit == 0: if limit == 0:
break break
@ -621,6 +631,11 @@ proc request*(
trace "exception catch in request", description = exc.msg trace "exception catch in request", description = exc.msg
return toSeq(s.values()).mapIt(it[0]) return toSeq(s.values()).mapIt(it[0])
proc request*(
rdv: RendezVous, ns: string, l: int = DiscoverLimit.int
): Future[seq[PeerRecord]] {.async.} =
await rdv.request(ns, l, rdv.peers)
proc unsubscribeLocally*(rdv: RendezVous, ns: string) = proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
let nsSalted = ns & rdv.salt let nsSalted = ns & rdv.salt
try: try:
@ -630,16 +645,15 @@ proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
except KeyError: except KeyError:
return return
proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} = proc unsubscribe*(rdv: RendezVous, ns: string, peerIds: seq[PeerId]) {.async.} =
# TODO: find a way to improve this, maybe something similar to the advertise
if ns.len notin 1 .. 255: if ns.len notin 1 .. 255:
raise newException(RendezVousError, "Invalid namespace") raise newException(RendezVousError, "Invalid namespace")
rdv.unsubscribeLocally(ns)
let msg = encode( let msg = encode(
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns))) Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
) )
proc unsubscribePeer(rdv: RendezVous, peerId: PeerId) {.async.} = proc unsubscribePeer(peerId: PeerId) {.async.} =
try: try:
let conn = await rdv.switch.dial(peerId, RendezVousCodec) let conn = await rdv.switch.dial(peerId, RendezVousCodec)
defer: defer:
@ -648,8 +662,16 @@ proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} =
except CatchableError as exc: except CatchableError as exc:
trace "exception while unsubscribing", description = exc.msg trace "exception while unsubscribing", description = exc.msg
for peer in rdv.peers: let futs = collect(newSeq()):
discard await rdv.unsubscribePeer(peer).withTimeout(5.seconds) for peer in peerIds:
unsubscribePeer(peer)
discard await allFutures(futs).withTimeout(5.seconds)
proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} =
rdv.unsubscribeLocally(ns)
await rdv.unsubscribe(ns, rdv.peers)
proc setup*(rdv: RendezVous, switch: Switch) = proc setup*(rdv: RendezVous, switch: Switch) =
rdv.switch = switch rdv.switch = switch
@ -662,7 +684,25 @@ proc setup*(rdv: RendezVous, switch: Switch) =
rdv.switch.addPeerEventHandler(handlePeer, Joined) rdv.switch.addPeerEventHandler(handlePeer, Joined)
rdv.switch.addPeerEventHandler(handlePeer, Left) rdv.switch.addPeerEventHandler(handlePeer, Left)
proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T = proc new*(
T: typedesc[RendezVous],
rng: ref HmacDrbgContext = newRng(),
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T {.raises: [RendezVousError].} =
if minDuration < 1.minutes:
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
if maxDuration > 72.hours:
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
if minDuration >= maxDuration:
raise newException(RendezVousError, "Minimum TTL longer than maximum")
let
minTTL = minDuration.seconds.uint64
maxTTL = maxDuration.seconds.uint64
let rdv = T( let rdv = T(
rng: rng, rng: rng,
salt: string.fromBytes(generateBytes(rng[], 8)), salt: string.fromBytes(generateBytes(rng[], 8)),
@ -670,6 +710,10 @@ proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T =
defaultDT: Moment.now() - 1.days, defaultDT: Moment.now() - 1.days,
#registerEvent: newAsyncEvent(), #registerEvent: newAsyncEvent(),
sema: newAsyncSemaphore(SemaphoreDefaultSize), sema: newAsyncSemaphore(SemaphoreDefaultSize),
minDuration: minDuration,
maxDuration: maxDuration,
minTTL: minTTL,
maxTTL: maxTTL,
) )
logScope: logScope:
topics = "libp2p discovery rendezvous" topics = "libp2p discovery rendezvous"
@ -701,9 +745,13 @@ proc new*(T: typedesc[RendezVous], rng: ref HmacDrbgContext = newRng()): T =
return rdv return rdv
proc new*( proc new*(
T: typedesc[RendezVous], switch: Switch, rng: ref HmacDrbgContext = newRng() T: typedesc[RendezVous],
switch: Switch,
rng: ref HmacDrbgContext = newRng(),
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T = ): T =
let rdv = T.new(rng) let rdv = T.new(rng, minDuration, maxDuration)
rdv.setup(switch) rdv.setup(switch)
return rdv return rdv

View File

@ -126,7 +126,7 @@ suite "RendezVous":
asyncTest "Various local error": asyncTest "Various local error":
let let
rdv = RendezVous.new() rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours)
switch = createSwitch(rdv) switch = createSwitch(rdv)
expect RendezVousError: expect RendezVousError:
discard await rdv.request("A".repeat(300)) discard await rdv.request("A".repeat(300))
@ -137,6 +137,14 @@ suite "RendezVous":
expect RendezVousError: expect RendezVousError:
await rdv.advertise("A".repeat(300)) await rdv.advertise("A".repeat(300))
expect RendezVousError: expect RendezVousError:
await rdv.advertise("A", 2.weeks) await rdv.advertise("A", 73.hours)
expect RendezVousError: expect RendezVousError:
await rdv.advertise("A", 5.minutes) await rdv.advertise("A", 30.seconds)
test "Various config error":
expect RendezVousError:
discard RendezVous.new(minDuration = 30.seconds)
expect RendezVousError:
discard RendezVous.new(maxDuration = 73.hours)
expect RendezVousError:
discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)