diff --git a/client.go b/client.go index 09ec733..e416433 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,7 @@ var ( ) type RendezvousPoint interface { - Register(ctx context.Context, ns string, ttl int) error + Register(ctx context.Context, ns string, ttl int) (time.Duration, error) Unregister(ctx context.Context, ns string) error Discover(ctx context.Context, ns string, limit int, cookie []byte) ([]Registration, []byte, error) DiscoverAsync(ctx context.Context, ns string) (<-chan Registration, error) @@ -33,7 +33,7 @@ type Registration struct { } type RendezvousClient interface { - Register(ctx context.Context, ns string, ttl int) error + Register(ctx context.Context, ns string, ttl int) (time.Duration, error) Unregister(ctx context.Context, ns string) error Discover(ctx context.Context, ns string, limit int, cookie []byte) ([]peer.AddrInfo, []byte, error) DiscoverAsync(ctx context.Context, ns string) (<-chan peer.AddrInfo, error) @@ -63,7 +63,7 @@ type rendezvousClient struct { rp RendezvousPoint } -func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) error { +func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) { s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto) if err != nil { return err @@ -76,39 +76,40 @@ func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) err req := newRegisterMessage(ns, peer.AddrInfo{ID: rp.host.ID(), Addrs: rp.host.Addrs()}, ttl) err = w.WriteMsg(req) if err != nil { - return err + return 0, err } var res pb.Message err = r.ReadMsg(&res) if err != nil { - return err + return 0, err } if res.GetType() != pb.Message_REGISTER_RESPONSE { - return fmt.Errorf("Unexpected response: %s", res.GetType().String()) + return 0, fmt.Errorf("Unexpected response: %s", res.GetType().String()) } - status := res.GetRegisterResponse().GetStatus() + response := res.GetRegisterResponse() + status := response.GetStatus() if status != pb.Message_OK { - return RendezvousError{Status: status, Text: res.GetRegisterResponse().GetStatusText()} + return 0, RendezvousError{Status: status, Text: res.GetRegisterResponse().GetStatusText()} } - return nil + return time.Duration(*response.Ttl) * time.Second, nil } -func (rc *rendezvousClient) Register(ctx context.Context, ns string, ttl int) error { +func (rc *rendezvousClient) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) { if ttl < 120 { - return fmt.Errorf("registration TTL is too short") + return 0, fmt.Errorf("registration TTL is too short") } - err := rc.rp.Register(ctx, ns, ttl) + returnedTTL, err := rc.rp.Register(ctx, ns, ttl) if err != nil { - return err + return 0, err } go registerRefresh(ctx, rc.rp, ns, ttl) - return nil + return returnedTTL, nil } func registerRefresh(ctx context.Context, rz RendezvousPoint, ns string, ttl int) { diff --git a/discovery_client.go b/discovery_client.go index 3a25253..30b9ffd 100644 --- a/discovery_client.go +++ b/discovery_client.go @@ -19,7 +19,8 @@ type rendezvousDiscoveryClient struct { } type discoveredPeerCache struct { - cachedRegs []Registration + cachedRegs map[peer.ID]*Registration + cookie []byte mux sync.Mutex } @@ -39,19 +40,17 @@ func (c *rendezvousDiscoveryClient) Advertise(ctx context.Context, ns string, op ttl := options.Ttl var ttlSeconds int - // Default is minimum duration if ttl == 0 { - ttlSeconds = 120 + ttlSeconds = 7200 } else { ttlSeconds = int(math.Round(ttl.Seconds())) } - if err := c.rp.Register(ctx, ns, ttlSeconds); err != nil { + if returnedTTL, err := c.rp.Register(ctx, ns, ttlSeconds); err != nil { return 0, err + } else { + return returnedTTL, nil } - - actualTTL := time.Duration(ttlSeconds) * time.Second - return actualTTL, nil } func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { @@ -75,34 +74,36 @@ func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, op cache = genericCache.(*discoveredPeerCache) cache.mux.Lock() - cachedRegs := cache.cachedRegs // Remove all expired entries from cache currentTime := int(time.Now().Unix()) - newCacheSize := len(cachedRegs) + newCacheSize := len(cache.cachedRegs) - for i := 0; i < newCacheSize; i++ { - reg := cachedRegs[i] + for p := range cache.cachedRegs { + reg := cache.cachedRegs[p] if reg.Ttl < currentTime { newCacheSize-- - if i != newCacheSize { - cachedRegs[i] = cachedRegs[newCacheSize] - i-- - } + delete(cache.cachedRegs, p) } } - cache.cachedRegs = cachedRegs[:newCacheSize] + + cookie := cache.cookie cache.mux.Unlock() // Discover new records if we don't have enough var discoveryErr error if newCacheSize < limit { - if discoveryRecords, _, err := c.rp.Discover(ctx, ns, limit, nil); err == nil { - for i := range discoveryRecords { - discoveryRecords[i].Ttl += currentTime - } + if discoveryRecords, newCookie, err := c.rp.Discover(ctx, ns, limit, cookie); err == nil { cache.mux.Lock() - cache.cachedRegs = discoveryRecords + if cache.cachedRegs == nil { + cache.cachedRegs = make(map[peer.ID]*Registration) + } + for i := range discoveryRecords { + rec := &discoveryRecords[i] + rec.Ttl += currentTime + cache.cachedRegs[rec.Peer.ID] = rec + } + cache.cookie = newCookie cache.mux.Unlock() } else { // TODO: Should we return error even if we have valid cached results? @@ -112,8 +113,7 @@ func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, op // Randomize and fill channel with available records cache.mux.Lock() - cachedRegs = cache.cachedRegs - sendQuantity := len(cachedRegs) + sendQuantity := len(cache.cachedRegs) if limit < sendQuantity { sendQuantity = limit } @@ -121,11 +121,25 @@ func (c *rendezvousDiscoveryClient) FindPeers(ctx context.Context, ns string, op chPeer := make(chan peer.AddrInfo, sendQuantity) c.rngMux.Lock() - perm := c.rng.Perm(len(cachedRegs))[0:sendQuantity] + perm := c.rng.Perm(len(cache.cachedRegs))[0:sendQuantity] c.rngMux.Unlock() - for _, i := range perm { - chPeer <- cachedRegs[i].Peer + permSet := make(map[int]int) + for i, v := range perm { + permSet[v] = i + } + + sendLst := make([]*peer.AddrInfo, sendQuantity) + iter := 0 + for k := range cache.cachedRegs { + if sendIndex, ok := permSet[iter]; ok { + sendLst[sendIndex] = &cache.cachedRegs[k].Peer + } + iter++ + } + + for _, send := range sendLst { + chPeer <- *send } cache.mux.Unlock()