diff --git a/discovery.go b/discovery.go index 94cac73..40e70e8 100644 --- a/discovery.go +++ b/discovery.go @@ -12,16 +12,17 @@ import ( ) type rendezvousDiscovery struct { - rp RendezvousPoint - peerCache sync.Map //is a map[string]discoveredPeerCache - rng *rand.Rand - rngMux sync.Mutex + rp RendezvousPoint + peerCache map[string]*discoveryCache + peerCacheMux sync.RWMutex + rng *rand.Rand + rngMux sync.Mutex } -type discoveredPeerCache struct { - cachedRecs map[peer.ID]*record - cookie []byte - mux sync.Mutex +type discoveryCache struct { + recs map[peer.ID]*record + cookie []byte + mux sync.Mutex } type record struct { @@ -31,7 +32,7 @@ type record struct { func NewRendezvousDiscovery(host host.Host, rendezvousPeer peer.ID) discovery.Discovery { rp := NewRendezvousPoint(host, rendezvousPeer) - return &rendezvousDiscovery{rp, sync.Map{}, rand.New(rand.NewSource(rand.Int63())), sync.Mutex{}} + return &rendezvousDiscovery{rp: rp, peerCache: make(map[string]*discoveryCache), rng: rand.New(rand.NewSource(rand.Int63()))} } func (c *rendezvousDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { @@ -51,10 +52,10 @@ func (c *rendezvousDiscovery) Advertise(ctx context.Context, ns string, opts ... ttlSeconds = int(math.Round(ttl.Seconds())) } - if returnedTTL, err := c.rp.Register(ctx, ns, ttlSeconds); err != nil { + if rttl, err := c.rp.Register(ctx, ns, ttlSeconds); err != nil { return 0, err } else { - return returnedTTL, nil + return rttl, nil } } @@ -73,23 +74,33 @@ func (c *rendezvousDiscovery) FindPeers(ctx context.Context, ns string, opts ... } // Get cached peers - var cache *discoveredPeerCache + var cache *discoveryCache - genericCache, _ := c.peerCache.LoadOrStore(ns, &discoveredPeerCache{}) - cache = genericCache.(*discoveredPeerCache) + c.peerCacheMux.RLock() + cache, ok := c.peerCache[ns] + c.peerCacheMux.RUnlock() + if !ok { + c.peerCacheMux.Lock() + cache, ok = c.peerCache[ns] + if !ok{ + cache = &discoveryCache{recs: make(map[peer.ID]*record)} + c.peerCache[ns] = cache + } + c.peerCacheMux.Unlock() + } cache.mux.Lock() defer cache.mux.Unlock() // Remove all expired entries from cache currentTime := time.Now().Unix() - newCacheSize := len(cache.cachedRecs) + newCacheSize := len(cache.recs) - for p := range cache.cachedRecs { - rec := cache.cachedRecs[p] + for p := range cache.recs { + rec := cache.recs[p] if rec.expire < currentTime { newCacheSize-- - delete(cache.cachedRecs, p) + delete(cache.recs, p) } } @@ -101,19 +112,16 @@ func (c *rendezvousDiscovery) FindPeers(ctx context.Context, ns string, opts ... var regs []Registration var newCookie []byte if regs, newCookie, err = c.rp.Discover(ctx, ns, limit, cookie); err == nil { - if cache.cachedRecs == nil { - cache.cachedRecs = make(map[peer.ID]*record) - } for _, reg := range regs { rec := &record{peer: reg.Peer, expire: int64(reg.Ttl) + currentTime} - cache.cachedRecs[rec.peer.ID] = rec + cache.recs[rec.peer.ID] = rec } cache.cookie = newCookie } } // Randomize and fill channel with available records - count := len(cache.cachedRecs) + count := len(cache.recs) if limit < count { count = limit } @@ -121,7 +129,7 @@ func (c *rendezvousDiscovery) FindPeers(ctx context.Context, ns string, opts ... chPeer := make(chan peer.AddrInfo, count) c.rngMux.Lock() - perm := c.rng.Perm(len(cache.cachedRecs))[0:count] + perm := c.rng.Perm(len(cache.recs))[0:count] c.rngMux.Unlock() permSet := make(map[int]int) @@ -131,9 +139,9 @@ func (c *rendezvousDiscovery) FindPeers(ctx context.Context, ns string, opts ... sendLst := make([]*peer.AddrInfo, count) iter := 0 - for k := range cache.cachedRecs { + for k := range cache.recs { if sendIndex, ok := permSet[iter]; ok { - sendLst[sendIndex] = &cache.cachedRecs[k].peer + sendLst[sendIndex] = &cache.recs[k].peer } iter++ } diff --git a/discovery_test.go b/discovery_test.go index 4091ef7..72161ce 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -6,7 +6,6 @@ import ( "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" "math/rand" - "sync" "testing" "time" ) @@ -17,7 +16,7 @@ func getRendezvousDiscovery(hosts []host.Host) []discovery.Discovery { for i, h := range hosts[1:] { rp := NewRendezvousPoint(h, rendezvousPeer) rng := rand.New(rand.NewSource(int64(i))) - clients[i] = &rendezvousDiscovery{rp: rp, peerCache: sync.Map{}, rng: rng} + clients[i] = &rendezvousDiscovery{rp: rp, peerCache: make(map[string]*discoveryCache), rng: rng} } return clients }