diff --git a/client.go b/client.go index e416433..6cd091b 100644 --- a/client.go +++ b/client.go @@ -66,7 +66,7 @@ type rendezvousClient struct { func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) { s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto) if err != nil { - return err + return 0, err } defer s.Close() @@ -134,7 +134,7 @@ func registerRefresh(ctx context.Context, rz RendezvousPoint, ns string, ttl int return } - err := rz.Register(ctx, ns, ttl) + _, err := rz.Register(ctx, ns, ttl) if err != nil { log.Errorf("Error registering [%s]: %s", ns, err.Error()) errcount++ diff --git a/client_test.go b/client_test.go index 6470932..db974b7 100644 --- a/client_test.go +++ b/client_test.go @@ -31,10 +31,13 @@ func TestClientRegistrationAndDiscovery(t *testing.T) { clients := getRendezvousClients(t, hosts) - err = clients[0].Register(ctx, "foo1", DefaultTTL) + recordTTL, err := clients[0].Register(ctx, "foo1", DefaultTTL) if err != nil { t.Fatal(err) } + if recordTTL != DefaultTTL*time.Second { + t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) + } pi, cookie, err := clients[0].Discover(ctx, "foo1", 0, nil) if err != nil { @@ -46,10 +49,13 @@ func TestClientRegistrationAndDiscovery(t *testing.T) { checkPeerInfo(t, pi[0], hosts[1]) for i, client := range clients[1:] { - err = client.Register(ctx, "foo1", DefaultTTL) + recordTTL, err = client.Register(ctx, "foo1", DefaultTTL) if err != nil { t.Fatal(err) } + if recordTTL != DefaultTTL*time.Second { + t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) + } pi, cookie, err = clients[0].Discover(ctx, "foo1", 10, cookie) if err != nil { @@ -98,10 +104,13 @@ func TestClientRegistrationAndDiscoveryAsync(t *testing.T) { } for i, client := range clients[0:] { - err = client.Register(ctx, "foo1", DefaultTTL) + recordTTL, err := client.Register(ctx, "foo1", DefaultTTL) if err != nil { t.Fatal(err) } + if recordTTL != DefaultTTL*time.Second { + t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) + } pi := <-ch checkPeerInfo(t, pi, hosts[1+i]) diff --git a/discovery_client.go b/discovery.go similarity index 52% rename from discovery_client.go rename to discovery.go index 30b9ffd..94cac73 100644 --- a/discovery_client.go +++ b/discovery.go @@ -11,7 +11,7 @@ import ( "time" ) -type rendezvousDiscoveryClient struct { +type rendezvousDiscovery struct { rp RendezvousPoint peerCache sync.Map //is a map[string]discoveredPeerCache rng *rand.Rand @@ -19,17 +19,22 @@ type rendezvousDiscoveryClient struct { } type discoveredPeerCache struct { - cachedRegs map[peer.ID]*Registration + cachedRecs map[peer.ID]*record cookie []byte mux sync.Mutex } -func NewRendezvousDiscoveryClient(host host.Host, rendezvousPeer peer.ID) discovery.Discovery { - rp := NewRendezvousPoint(host, rendezvousPeer) - return &rendezvousDiscoveryClient{rp, sync.Map{}, rand.New(rand.NewSource(rand.Int63())), sync.Mutex{}} +type record struct { + peer peer.AddrInfo + expire int64 } -func (c *rendezvousDiscoveryClient) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { +func NewRendezvousDiscovery(host host.Host, rendezvousPeer peer.ID) discovery.Discovery { + rp := NewRendezvousPoint(host, rendezvousPeer) + return &rendezvousDiscovery{rp, sync.Map{}, rand.New(rand.NewSource(rand.Int63())), sync.Mutex{}} +} + +func (c *rendezvousDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { // Get options var options discovery.Options err := options.Apply(opts...) @@ -53,7 +58,7 @@ func (c *rendezvousDiscoveryClient) Advertise(ctx context.Context, ns string, op } } -func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { +func (c *rendezvousDiscovery) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { // Get options var options discovery.Options err := options.Apply(opts...) @@ -74,54 +79,49 @@ func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, op cache = genericCache.(*discoveredPeerCache) cache.mux.Lock() + defer cache.mux.Unlock() // Remove all expired entries from cache - currentTime := int(time.Now().Unix()) - newCacheSize := len(cache.cachedRegs) + currentTime := time.Now().Unix() + newCacheSize := len(cache.cachedRecs) - for p := range cache.cachedRegs { - reg := cache.cachedRegs[p] - if reg.Ttl < currentTime { + for p := range cache.cachedRecs { + rec := cache.cachedRecs[p] + if rec.expire < currentTime { newCacheSize-- - delete(cache.cachedRegs, p) + delete(cache.cachedRecs, p) } } cookie := cache.cookie - cache.mux.Unlock() // Discover new records if we don't have enough - var discoveryErr error if newCacheSize < limit { - if discoveryRecords, newCookie, err := c.rp.Discover(ctx, ns, limit, cookie); err == nil { - cache.mux.Lock() - if cache.cachedRegs == nil { - cache.cachedRegs = make(map[peer.ID]*Registration) + // TODO: Should we return error even if we have valid cached results? + var regs []Registration + var newCookie []byte + if regs, newCookie, err = c.rp.Discover(ctx, ns, limit, cookie); err == nil { + if cache.cachedRecs == nil { + cache.cachedRecs = make(map[peer.ID]*record) } - for i := range discoveryRecords { - rec := &discoveryRecords[i] - rec.Ttl += currentTime - cache.cachedRegs[rec.Peer.ID] = rec + for _, reg := range regs { + rec := &record{peer: reg.Peer, expire: int64(reg.Ttl) + currentTime} + cache.cachedRecs[rec.peer.ID] = rec } cache.cookie = newCookie - cache.mux.Unlock() - } else { - // TODO: Should we return error even if we have valid cached results? - discoveryErr = err } } // Randomize and fill channel with available records - cache.mux.Lock() - sendQuantity := len(cache.cachedRegs) - if limit < sendQuantity { - sendQuantity = limit + count := len(cache.cachedRecs) + if limit < count { + count = limit } - chPeer := make(chan peer.AddrInfo, sendQuantity) + chPeer := make(chan peer.AddrInfo, count) c.rngMux.Lock() - perm := c.rng.Perm(len(cache.cachedRegs))[0:sendQuantity] + perm := c.rng.Perm(len(cache.cachedRecs))[0:count] c.rngMux.Unlock() permSet := make(map[int]int) @@ -129,11 +129,11 @@ func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, op permSet[v] = i } - sendLst := make([]*peer.AddrInfo, sendQuantity) + sendLst := make([]*peer.AddrInfo, count) iter := 0 - for k := range cache.cachedRegs { + for k := range cache.cachedRecs { if sendIndex, ok := permSet[iter]; ok { - sendLst[sendIndex] = &cache.cachedRegs[k].Peer + sendLst[sendIndex] = &cache.cachedRecs[k].peer } iter++ } @@ -142,7 +142,6 @@ func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, op chPeer <- *send } - cache.mux.Unlock() close(chPeer) - return chPeer, discoveryErr + return chPeer, err } diff --git a/discovery_client_test.go b/discovery_test.go similarity index 98% rename from discovery_client_test.go rename to discovery_test.go index 854737e..4091ef7 100644 --- a/discovery_client_test.go +++ b/discovery_test.go @@ -17,7 +17,7 @@ func getRendezvousDiscovery(hosts []host.Host) []discovery.Discovery { for i, h := range hosts[1:] { rp := NewRendezvousPoint(h, rendezvousPeer) rng := rand.New(rand.NewSource(int64(i))) - clients[i] = &rendezvousDiscoveryClient{rp: rp, peerCache: sync.Map{}, rng: rng} + clients[i] = &rendezvousDiscovery{rp: rp, peerCache: sync.Map{}, rng: rng} } return clients } diff --git a/svc_test.go b/svc_test.go index eaea97d..cc311d1 100644 --- a/svc_test.go +++ b/svc_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "testing" + "time" db "github.com/libp2p/go-libp2p-rendezvous/db/sqlite" pb "github.com/libp2p/go-libp2p-rendezvous/pb" @@ -77,10 +78,14 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) { clients := getRendezvousPoints(t, hosts) - err = clients[0].Register(ctx, "foo1", 60) + const registerTTL = 60 + recordTTL, err := clients[0].Register(ctx, "foo1", registerTTL) if err != nil { t.Fatal(err) } + if recordTTL != registerTTL*time.Second { + t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) + } rrs, cookie, err := clients[0].Discover(ctx, "foo1", 10, nil) if err != nil { @@ -92,10 +97,13 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) { checkHostRegistration(t, rrs[0], hosts[1]) for i, client := range clients[1:] { - err = client.Register(ctx, "foo1", 60) + recordTTL, err = client.Register(ctx, "foo1", registerTTL) if err != nil { t.Fatal(err) } + if recordTTL != registerTTL*time.Second { + t.Fatalf("Expected record TTL to be %d seconds", DefaultTTL) + } rrs, cookie, err = clients[0].Discover(ctx, "foo1", 10, cookie) if err != nil {