From ba3a59f2257f5889148705b8b74566495fa6a05d Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Thu, 30 Sep 2021 10:39:04 -0400 Subject: [PATCH] refactor: obtain the peers from peerstore --- client.go | 44 ++++++++++++++++++++++++++++++++++---------- discovery.go | 4 ++-- proto.go | 5 ++--- svc.go | 2 +- 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 576e697..1978894 100644 --- a/client.go +++ b/client.go @@ -37,20 +37,18 @@ type RendezvousClient interface { DiscoverAsync(ctx context.Context, ns string) (<-chan peer.AddrInfo, error) } -func NewRendezvousPoint(host host.Host, p []peer.ID) RendezvousPoint { +func NewRendezvousPoint(host host.Host) RendezvousPoint { return &rendezvousPoint{ host: host, - p: p, } } type rendezvousPoint struct { host host.Host - p []peer.ID } -func NewRendezvousClient(host host.Host, rp []peer.ID) RendezvousClient { - return NewRendezvousClientWithPoint(NewRendezvousPoint(host, rp)) +func NewRendezvousClient(host host.Host) RendezvousClient { + return NewRendezvousClientWithPoint(NewRendezvousPoint(host)) } func NewRendezvousClientWithPoint(rp RendezvousPoint) RendezvousClient { @@ -61,12 +59,28 @@ type rendezvousClient struct { rp RendezvousPoint } -func (r *rendezvousPoint) getRandomPeer() peer.ID { - return r.p[rand.Intn(len(r.p))] // nolint: gosec +func (r *rendezvousPoint) getRandomPeer() (peer.ID, error) { + var peerIDs []peer.ID + for _, peer := range r.host.Peerstore().Peers() { + protocols, err := r.host.Peerstore().SupportsProtocols(peer, string(RendezvousID_v001)) + if err != nil { + log.Error("error obtaining the protocols supported by peers", err) + return "", err + } + if len(protocols) > 0 { + peerIDs = append(peerIDs, peer) + } + } + return peerIDs[rand.Intn(len(peerIDs))], nil // nolint: gosec } func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) { - s, err := rp.host.NewStream(ctx, rp.getRandomPeer(), RendezvousProto) + randomPeer, err := rp.getRandomPeer() + if err != nil { + return 0, err + } + + s, err := rp.host.NewStream(ctx, randomPeer, RendezvousID_v001) if err != nil { return 0, err } @@ -147,7 +161,12 @@ func registerRefresh(ctx context.Context, rz RendezvousPoint, ns string, ttl int } func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int) ([]Registration, error) { - s, err := rp.host.NewStream(ctx, rp.getRandomPeer(), RendezvousProto) + randomPeer, err := rp.getRandomPeer() + if err != nil { + return nil, err + } + + s, err := rp.host.NewStream(ctx, randomPeer, RendezvousID_v001) if err != nil { return nil, err } @@ -196,7 +215,12 @@ func discoverQuery(ns string, limit int, r ggio.Reader, w ggio.Writer) ([]Regist } func (rp *rendezvousPoint) DiscoverAsync(ctx context.Context, ns string) (<-chan Registration, error) { - s, err := rp.host.NewStream(ctx, rp.getRandomPeer(), RendezvousProto) + randomPeer, err := rp.getRandomPeer() + if err != nil { + return nil, err + } + + s, err := rp.host.NewStream(ctx, randomPeer, RendezvousID_v001) if err != nil { return nil, err } diff --git a/discovery.go b/discovery.go index 7d2f73b..0a9d0ce 100644 --- a/discovery.go +++ b/discovery.go @@ -31,8 +31,8 @@ type record struct { expire int64 } -func NewRendezvousDiscovery(host host.Host, rendezvousPeers []peer.ID) discovery.Discovery { - rp := NewRendezvousPoint(host, rendezvousPeers) +func NewRendezvousDiscovery(host host.Host) discovery.Discovery { + rp := NewRendezvousPoint(host) return &rendezvousDiscovery{rp: rp, peerCache: make(map[string]*discoveryCache), rng: rand.New(rand.NewSource(rand.Int63()))} } diff --git a/proto.go b/proto.go index 6464565..9687a16 100644 --- a/proto.go +++ b/proto.go @@ -15,9 +15,8 @@ import ( var log = logging.Logger("rendezvous") const ( - RendezvousProto = protocol.ID("/vac/waku/rendezvous/0.0.1") - - DefaultTTL = 2 * 3600 // 2hr + RendezvousID_v001 = protocol.ID("/vac/waku/rendezvous/0.0.1") + DefaultTTL = 2 * 3600 // 2hr ) type RendezvousError struct { diff --git a/svc.go b/svc.go index 66fb2f1..759168b 100644 --- a/svc.go +++ b/svc.go @@ -48,7 +48,7 @@ func NewRendezvousService(host host.Host, storage Storage, rzs ...RendezvousSync } func (rz *RendezvousService) Start() error { - rz.h.SetStreamHandler(RendezvousProto, rz.handleStream) + rz.h.SetStreamHandler(RendezvousID_v001, rz.handleStream) if err := rz.startCleaner(); err != nil { return err