diff --git a/discovery_client.go b/discovery_client.go new file mode 100644 index 0000000..3a25253 --- /dev/null +++ b/discovery_client.go @@ -0,0 +1,134 @@ +package rendezvous + +import ( + "context" + "github.com/libp2p/go-libp2p-core/discovery" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + "math" + "math/rand" + "sync" + "time" +) + +type rendezvousDiscoveryClient struct { + rp RendezvousPoint + peerCache sync.Map //is a map[string]discoveredPeerCache + rng *rand.Rand + rngMux sync.Mutex +} + +type discoveredPeerCache struct { + cachedRegs []Registration + mux sync.Mutex +} + +func NewRendezvousDiscoveryClient(host host.Host, rendezvousPeer peer.ID) discovery.Discovery { + rp := NewRendezvousPoint(host, rendezvousPeer) + return &rendezvousDiscoveryClient{rp, sync.Map{}, rand.New(rand.NewSource(rand.Int63())), sync.Mutex{}} +} + +func (c *rendezvousDiscoveryClient) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { + // Get options + var options discovery.Options + err := options.Apply(opts...) + if err != nil { + return 0, err + } + + ttl := options.Ttl + var ttlSeconds int + + // Default is minimum duration + if ttl == 0 { + ttlSeconds = 120 + } else { + ttlSeconds = int(math.Round(ttl.Seconds())) + } + + if err := c.rp.Register(ctx, ns, ttlSeconds); err != nil { + return 0, err + } + + actualTTL := time.Duration(ttlSeconds) * time.Second + return actualTTL, nil +} + +func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { + // Get options + var options discovery.Options + err := options.Apply(opts...) + if err != nil { + return nil, err + } + + const maxLimit = 1000 + limit := options.Limit + if limit == 0 || limit > maxLimit { + limit = maxLimit + } + + // Get cached peers + var cache *discoveredPeerCache + + genericCache, _ := c.peerCache.LoadOrStore(ns, &discoveredPeerCache{}) + cache = genericCache.(*discoveredPeerCache) + + cache.mux.Lock() + cachedRegs := cache.cachedRegs + + // Remove all expired entries from cache + currentTime := int(time.Now().Unix()) + newCacheSize := len(cachedRegs) + + for i := 0; i < newCacheSize; i++ { + reg := cachedRegs[i] + if reg.Ttl < currentTime { + newCacheSize-- + if i != newCacheSize { + cachedRegs[i] = cachedRegs[newCacheSize] + i-- + } + } + } + cache.cachedRegs = cachedRegs[:newCacheSize] + cache.mux.Unlock() + + // Discover new records if we don't have enough + var discoveryErr error + if newCacheSize < limit { + if discoveryRecords, _, err := c.rp.Discover(ctx, ns, limit, nil); err == nil { + for i := range discoveryRecords { + discoveryRecords[i].Ttl += currentTime + } + cache.mux.Lock() + cache.cachedRegs = discoveryRecords + cache.mux.Unlock() + } else { + // TODO: Should we return error even if we have valid cached results? + discoveryErr = err + } + } + + // Randomize and fill channel with available records + cache.mux.Lock() + cachedRegs = cache.cachedRegs + sendQuantity := len(cachedRegs) + if limit < sendQuantity { + sendQuantity = limit + } + + chPeer := make(chan peer.AddrInfo, sendQuantity) + + c.rngMux.Lock() + perm := c.rng.Perm(len(cachedRegs))[0:sendQuantity] + c.rngMux.Unlock() + + for _, i := range perm { + chPeer <- cachedRegs[i].Peer + } + + cache.mux.Unlock() + close(chPeer) + return chPeer, discoveryErr +} diff --git a/discovery_client_test.go b/discovery_client_test.go new file mode 100644 index 0000000..854737e --- /dev/null +++ b/discovery_client_test.go @@ -0,0 +1,163 @@ +package rendezvous + +import ( + "context" + "github.com/libp2p/go-libp2p-core/discovery" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + "math/rand" + "sync" + "testing" + "time" +) + +func getRendezvousDiscovery(hosts []host.Host) []discovery.Discovery { + clients := make([]discovery.Discovery, len(hosts)-1) + rendezvousPeer := hosts[0].ID() + for i, h := range hosts[1:] { + rp := NewRendezvousPoint(h, rendezvousPeer) + rng := rand.New(rand.NewSource(int64(i))) + clients[i] = &rendezvousDiscoveryClient{rp: rp, peerCache: sync.Map{}, rng: rng} + } + return clients +} + +func peerChannelToArray(pch <-chan peer.AddrInfo) []peer.AddrInfo { + pi := make([]peer.AddrInfo, len(pch)) + peerIndex := 0 + for p := range pch { + pi[peerIndex] = p + peerIndex++ + } + return pi +} + +func checkAvailablePeers(t *testing.T, ctx context.Context, client discovery.Discovery, namespace string, expectedNumPeers int) { + pch, err := client.FindPeers(ctx, namespace) + if err != nil { + t.Fatal(err) + } + + pi := peerChannelToArray(pch) + + if len(pi) != expectedNumPeers { + t.Fatalf("Expected %d peers", expectedNumPeers) + } +} + +func TestDiscoveryClientAdvertiseAndFindPeers(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Define parameters + const namespace = "foo1" + const numClients = 4 + const ttl = DefaultTTL * time.Second + + // Instantiate server and clients + hosts := getRendezvousHosts(t, ctx, numClients+1) + + svc, err := makeRendezvousService(ctx, hosts[0], ":memory:") + if err != nil { + t.Fatal(err) + } + defer svc.DB.Close() + + clients := getRendezvousDiscovery(hosts) + + // Advertise and check one peer + _, err = clients[0].Advertise(ctx, namespace, discovery.TTL(ttl)) + if err != nil { + t.Fatal(err) + } + + checkAvailablePeers(t, ctx, clients[0], namespace, 1) + + // Advertise and check the rest of the peers incrementally + for i, client := range clients[1:] { + if _, err = client.Advertise(ctx, namespace, discovery.TTL(ttl)); err != nil { + t.Fatal(err) + } + + checkAvailablePeers(t, ctx, client, namespace, i+2) + } + + // Check that the first peer can get all the new records + checkAvailablePeers(t, ctx, clients[0], namespace, numClients) +} + +func TestDiscoveryClientExpiredCachedRecords(t *testing.T) { + BaseDiscoveryClientCacheExpirationTest(t, true) +} + +func TestDiscoveryClientExpiredManyCachedRecords(t *testing.T) { + BaseDiscoveryClientCacheExpirationTest(t, false) +} + +func BaseDiscoveryClientCacheExpirationTest(t *testing.T, onlyRequestFromCache bool) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Define parameters + const numShortLivedRegs = 5 + const everyIthRegIsLongTTL = 2 + const numBaseRegs = numShortLivedRegs * everyIthRegIsLongTTL + const namespace = "foo1" + const longTTL = DefaultTTL * time.Second + const shortTTL = 2 * time.Second + + // Instantiate server and clients + hosts := getRendezvousHosts(t, ctx, numBaseRegs+3) + + svc, err := makeRendezvousService(ctx, hosts[0], ":memory:") + if err != nil { + t.Fatal(err) + } + defer svc.DB.Close() + clients := getRendezvousDiscovery(hosts) + + // Advertise most clients + for i, client := range clients[2:] { + ttl := shortTTL + if i%everyIthRegIsLongTTL == 0 { + ttl = longTTL + } + + if _, err = client.Advertise(ctx, namespace, discovery.TTL(ttl)); err != nil { + t.Fatal(err) + } + } + + // Find peers from an unrelated client (results should be cached) + pch, err := clients[0].FindPeers(ctx, namespace) + if err != nil { + t.Fatal(err) + } + pi := peerChannelToArray(pch) + if len(pi) != numBaseRegs { + t.Fatalf("expected %d registrations", numBaseRegs) + } + + // Advertise from a new unrelated peer + if _, err := clients[1].Advertise(ctx, namespace, discovery.TTL(longTTL)); err != nil { + t.Fatal(err) + } + + // Wait for cache expiration + time.Sleep(shortTTL + time.Second) + + // Check if number of retrieved records matches caching expectations after expiration + expectedNumClients := numShortLivedRegs + if !onlyRequestFromCache { + expectedNumClients++ + } + pch, err = clients[0].FindPeers(ctx, namespace, discovery.Limit(expectedNumClients)) + if err != nil { + t.Fatal(err) + } + pi = peerChannelToArray(pch) + + if len(pi) != expectedNumClients { + t.Fatalf("received an incorrect number of records: %d", len(pi)) + } +}