diff --git a/client.go b/client.go index e04745e..1b01c31 100644 --- a/client.go +++ b/client.go @@ -18,9 +18,10 @@ import ( var log = logging.Logger("rendezvous") type Rendezvous interface { - Register(ctx context.Context, ns string, ttl int) error + RegisterOnce(ctx context.Context, ns string, ttl int) error + Register(ctx context.Context, ns string) error Unregister(ctx context.Context, ns string) error - DiscoverOnce(ctx context.Context, ns string, limit int) ([]pstore.PeerInfo, error) + DiscoverOnce(ctx context.Context, ns string, limit int, cookie []byte) ([]pstore.PeerInfo, []byte, error) Discover(ctx context.Context, ns string) (<-chan pstore.PeerInfo, error) } @@ -36,18 +37,22 @@ type client struct { rp peer.ID } -func (cli *client) Register(ctx context.Context, ns string, ttl int) error { +func (cli *client) RegisterOnce(ctx context.Context, ns string, ttl int) error { s, err := cli.host.NewStream(ctx, cli.rp, RendezvousProto) if err != nil { return err } defer s.Close() + return cli.registerOnce(ctx, ns, ttl, s) +} + +func (cli *client) registerOnce(ctx context.Context, ns string, ttl int, s inet.Stream) error { r := ggio.NewDelimitedReader(s, 1<<20) w := ggio.NewDelimitedWriter(s) req := newRegisterMessage(ns, pstore.PeerInfo{ID: cli.host.ID(), Addrs: cli.host.Addrs()}, ttl) - err = w.WriteMsg(req) + err := w.WriteMsg(req) if err != nil { return err } @@ -70,6 +75,36 @@ func (cli *client) Register(ctx context.Context, ns string, ttl int) error { return nil } +func (cli *client) Register(ctx context.Context, ns string) error { + s, err := cli.host.NewStream(ctx, cli.rp, RendezvousProto) + if err != nil { + return err + } + + go cli.doRegister(ctx, ns, s) + return nil +} + +func (cli *client) doRegister(ctx context.Context, ns string, s inet.Stream) { + const ttl = 2 * 3600 // 2hr + const refresh = ttl - 30 + + defer s.Close() + for { + err := cli.registerOnce(ctx, ns, ttl, s) + if err != nil { + log.Errorf("Error registering: %s", err.Error()) + return + } + + select { + case <-time.After(refresh * time.Second): + case <-ctx.Done(): + return + } + } +} + func (cli *client) Unregister(ctx context.Context, ns string) error { s, err := cli.host.NewStream(ctx, cli.rp, RendezvousProto) if err != nil { @@ -82,30 +117,34 @@ func (cli *client) Unregister(ctx context.Context, ns string) error { return w.WriteMsg(req) } -func (cli *client) DiscoverOnce(ctx context.Context, ns string, limit int) ([]pstore.PeerInfo, error) { +func (cli *client) DiscoverOnce(ctx context.Context, ns string, limit int, cookie []byte) ([]pstore.PeerInfo, []byte, error) { s, err := cli.host.NewStream(ctx, cli.rp, RendezvousProto) if err != nil { - return nil, err + return nil, nil, err } defer s.Close() + return cli.discoverOnce(ctx, ns, limit, cookie, s) +} + +func (cli *client) discoverOnce(ctx context.Context, ns string, limit int, cookie []byte, s inet.Stream) ([]pstore.PeerInfo, []byte, error) { r := ggio.NewDelimitedReader(s, 1<<20) w := ggio.NewDelimitedWriter(s) - req := newDiscoverMessage(ns, limit) - err = w.WriteMsg(req) + req := newDiscoverMessage(ns, limit, cookie) + err := w.WriteMsg(req) if err != nil { - return nil, err + return nil, nil, err } var res pb.Message err = r.ReadMsg(&res) if err != nil { - return nil, err + return nil, nil, err } if res.GetType() != pb.Message_DISCOVER_RESPONSE { - return nil, fmt.Errorf("Unexpected response: %s", res.GetType().String()) + return nil, nil, fmt.Errorf("Unexpected response: %s", res.GetType().String()) } regs := res.GetDiscoverResponse().GetRegistrations() @@ -119,7 +158,7 @@ func (cli *client) DiscoverOnce(ctx context.Context, ns string, limit int) ([]ps pinfos = append(pinfos, pi) } - return pinfos, nil + return pinfos, res.GetDiscoverResponse().GetCookie(), nil } func (cli *client) Discover(ctx context.Context, ns string) (<-chan pstore.PeerInfo, error) { @@ -129,58 +168,37 @@ func (cli *client) Discover(ctx context.Context, ns string) (<-chan pstore.PeerI } ch := make(chan pstore.PeerInfo) - go doDiscovery(ctx, ns, s, ch) + go cli.doDiscover(ctx, ns, s, ch) return ch, nil } -func doDiscovery(ctx context.Context, ns string, s inet.Stream, ch chan pstore.PeerInfo) { +func (cli *client) doDiscover(ctx context.Context, ns string, s inet.Stream, ch chan pstore.PeerInfo) { defer s.Close() defer close(ch) const batch = 100 - r := ggio.NewDelimitedReader(s, 1<<20) - w := ggio.NewDelimitedWriter(s) - - req := newDiscoverMessage(ns, batch) - + var ( + cookie []byte + pi []pstore.PeerInfo + err error + ) for { - err := w.WriteMsg(req) + pi, cookie, err = cli.discoverOnce(ctx, ns, batch, cookie, s) if err != nil { - log.Errorf("Error sending Discover request: %s", err.Error()) + log.Errorf("Error in discovery: %s", err.Error()) return } - var res pb.Message - err = r.ReadMsg(&res) - if err != nil { - log.Errorf("Error reading discover response: %s", err.Error()) - return - } - - if res.GetType() != pb.Message_DISCOVER_RESPONSE { - log.Errorf("Unexpected response: %s", res.GetType().String()) - return - } - - regs := res.GetDiscoverResponse().GetRegistrations() - for _, reg := range regs { - pinfo, err := pbToPeerInfo(reg.GetPeer()) - if err != nil { - log.Errorf("Invalid peer info: %s", err.Error()) - continue - } - + for _, p := range pi { select { - case ch <- pinfo: + case ch <- p: case <-ctx.Done(): return } } - req.Discover.Cookie = res.GetDiscoverResponse().GetCookie() - - if len(regs) < batch { + if len(pi) < batch { select { case <-time.After(1 * time.Minute): case <-ctx.Done(): diff --git a/proto.go b/proto.go index 9021c5c..2ca1d3e 100644 --- a/proto.go +++ b/proto.go @@ -44,7 +44,7 @@ func newUnregisterMessage(ns string, pid peer.ID) *pb.Message { return msg } -func newDiscoverMessage(ns string, limit int) *pb.Message { +func newDiscoverMessage(ns string, limit int, cookie []byte) *pb.Message { msg := new(pb.Message) msg.Type = pb.Message_DISCOVER.Enum() msg.Discover = new(pb.Message_Discover) @@ -55,6 +55,9 @@ func newDiscoverMessage(ns string, limit int) *pb.Message { limit64 := int64(limit) msg.Discover.Limit = &limit64 } + if cookie != nil { + msg.Discover.Cookie = cookie + } return msg }