diff --git a/floodsub.go b/floodsub.go index c52d0b3..cdedce3 100644 --- a/floodsub.go +++ b/floodsub.go @@ -31,16 +31,16 @@ type PubSub struct { publish chan *Message // addSub is a control channel for us to add and remove subscriptions - addSub chan *addSub + addSub chan *addSubReq - // + // get list of topics we are subscribed to getTopics chan *topicReq - // + // get chan of peers we are connected to getPeers chan *listPeerReq - // - addFeedHook chan *addFeedReq + // send subscription here to cancel it + cancelCh chan *Subscription // a notification channel for incoming streams from other peers newPeers chan inet.Stream @@ -49,7 +49,7 @@ type PubSub struct { peerDead chan peer.ID // The set of topics we are subscribed to - myTopics map[string][]*clientFeed + myTopics map[string]map[*Subscription]struct{} // topics tracks which topics each of our peers are subscribed to topics map[string]map[peer.ID]struct{} @@ -83,11 +83,11 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { publish: make(chan *Message), newPeers: make(chan inet.Stream), peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), getPeers: make(chan *listPeerReq), - addSub: make(chan *addSub), + addSub: make(chan *addSubReq), getTopics: make(chan *topicReq), - addFeedHook: make(chan *addFeedReq, 32), - myTopics: make(map[string][]*clientFeed), + myTopics: make(map[string]map[*Subscription]struct{}), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), seenMessages: timecache.NewTimeCache(time.Second * 30), @@ -118,21 +118,6 @@ func (p *PubSub) processLoop(ctx context.Context) { p.peers[pid] = messages - case req := <-p.addFeedHook: - feeds, ok := p.myTopics[req.topic] - - var out chan *Message - if ok { - out = make(chan *Message, 32) - nfeed := &clientFeed{ - out: out, - ctx: req.ctx, - } - - p.myTopics[req.topic] = append(feeds, nfeed) - } - - req.resp <- out case pid := <-p.peerDead: ch, ok := p.peers[pid] if ok { @@ -145,12 +130,16 @@ func (p *PubSub) processLoop(ctx context.Context) { } case treq := <-p.getTopics: var out []string - for t := range p.myTopics { - out = append(out, t) + for t, subs := range p.myTopics { + if len(subs) > 0 { + out = append(out, t) + } } treq.resp <- out + case sub := <-p.cancelCh: + p.handleRemoveSubscription(sub) case sub := <-p.addSub: - p.handleSubscriptionChange(sub) + p.handleAddSubscription(sub) case preq := <-p.getPeers: tmap, ok := p.topics[preq.topic] if preq.topic != "" && !ok { @@ -183,28 +172,51 @@ func (p *PubSub) processLoop(ctx context.Context) { } } -func (p *PubSub) handleSubscriptionChange(sub *addSub) { - subopt := &pb.RPC_SubOpts{ - Topicid: &sub.topic, - Subscribe: &sub.sub, +func (p *PubSub) handleRemoveSubscription(sub *Subscription) { + subs := p.myTopics[sub.topic] + + if subs == nil { + return } - feeds, ok := p.myTopics[sub.topic] - if sub.sub { - if ok { - return - } + sub.err = fmt.Errorf("subscription cancelled by calling sub.Cancel()") + close(sub.ch) + delete(subs, sub) - p.myTopics[sub.topic] = nil - } else { - if !ok { - return - } + if len(subs) == 0 { + p.announce(sub.topic, false) + } +} - for _, f := range feeds { - close(f.out) - } - delete(p.myTopics, sub.topic) +func (p *PubSub) handleAddSubscription(req *addSubReq) { + subs := p.myTopics[req.topic] + + // announce we want this topic + if len(subs) == 0 { + p.announce(req.topic, true) + } + + // make new if not there + if subs == nil { + p.myTopics[req.topic] = make(map[*Subscription]struct{}) + subs = p.myTopics[req.topic] + } + + sub := &Subscription{ + ch: make(chan *Message, 32), + topic: req.topic, + cancelCh: p.cancelCh, + } + + p.myTopics[sub.topic][sub] = struct{}{} + + req.resp <- sub +} + +func (p *PubSub) announce(topic string, sub bool) { + subopt := &pb.RPC_SubOpts{ + Topicid: &topic, + Subscribe: &sub, } out := rpcWithSubs(subopt) @@ -215,26 +227,9 @@ func (p *PubSub) handleSubscriptionChange(sub *addSub) { func (p *PubSub) notifySubs(msg *pb.Message) { for _, topic := range msg.GetTopicIDs() { - var cleanup bool - feeds := p.myTopics[topic] - for _, f := range feeds { - select { - case f.out <- &Message{msg}: - case <-f.ctx.Done(): - close(f.out) - f.out = nil - cleanup = true - } - } - - if cleanup { - out := make([]*clientFeed, 0, len(feeds)) - for _, f := range feeds { - if f.out != nil { - out = append(out, f) - } - } - p.myTopics[topic] = out + subs := p.myTopics[topic] + for f := range subs { + f.ch <- &Message{msg} } } } @@ -337,22 +332,36 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error { return nil } -type addSub struct { +type addSubReq struct { topic string - sub bool - resp chan chan *Message + resp chan *Subscription } -func (p *PubSub) Subscribe(ctx context.Context, topic string) (<-chan *Message, error) { - err := p.AddTopicSubscription(&pb.TopicDescriptor{ +func (p *PubSub) Subscribe(topic string) (*Subscription, error) { + td := &pb.TopicDescriptor{ Name: proto.String(topic), - }) - - if err != nil { - return nil, err } - return p.GetFeed(ctx, topic) + if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { + return nil, fmt.Errorf("Auth method not yet supported") + } + + if td.GetEnc().GetMode() != pb.TopicDescriptor_EncOpts_NONE { + return nil, fmt.Errorf("Encryption method not yet supported") + } + + out := make(chan *Subscription, 1) + p.addSub <- &addSubReq{ + topic: topic, + resp: out, + } + + resp := <-out + if resp == nil { + return nil, fmt.Errorf("not subscribed to topic %s", topic) + } + + return resp, nil } type topicReq struct { @@ -365,56 +374,6 @@ func (p *PubSub) GetTopics() []string { return <-out } -func (p *PubSub) AddTopicSubscription(td *pb.TopicDescriptor) error { - if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { - return fmt.Errorf("Auth method not yet supported") - } - - if td.GetEnc().GetMode() != pb.TopicDescriptor_EncOpts_NONE { - return fmt.Errorf("Encryption method not yet supported") - } - - p.addSub <- &addSub{ - topic: td.GetName(), - sub: true, - } - - return nil -} - -type addFeedReq struct { - ctx context.Context - topic string - resp chan chan *Message -} - -type clientFeed struct { - out chan *Message - ctx context.Context -} - -func (p *PubSub) GetFeed(ctx context.Context, topic string) (<-chan *Message, error) { - out := make(chan chan *Message, 1) - p.addFeedHook <- &addFeedReq{ - ctx: ctx, - topic: topic, - resp: out, - } - - resp := <-out - if resp == nil { - return nil, fmt.Errorf("not subscribed to topic %s", topic) - } - return resp, nil -} - -func (p *PubSub) Unsub(topic string) { - p.addSub <- &addSub{ - topic: topic, - sub: false, - } -} - func (p *PubSub) Publish(topic string, data []byte) error { seqno := make([]byte, 8) binary.BigEndian.PutUint64(seqno, uint64(time.Now().UnixNano())) diff --git a/floodsub_test.go b/floodsub_test.go index dc87cc2..bac78bc 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -14,7 +14,7 @@ import ( netutil "github.com/libp2p/go-libp2p/p2p/test/util" ) -func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []<-chan *Message) { +func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []*Subscription) { data := make([]byte, 16) rand.Read(data) @@ -85,13 +85,14 @@ func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub { return psubs } -func assertReceive(t *testing.T, ch <-chan *Message, exp []byte) { +func assertReceive(t *testing.T, ch *Subscription, exp []byte) { select { - case msg := <-ch: + case msg := <-ch.ch: if !bytes.Equal(msg.GetData(), exp) { t.Fatalf("got wrong message, expected %s but got %s", string(exp), string(msg.GetData())) } case <-time.After(time.Second * 5): + t.Logf("%#v\n", ch) t.Fatal("timed out waiting for message of: ", string(exp)) } } @@ -103,9 +104,9 @@ func TestBasicFloodsub(t *testing.T) { psubs := getPubsubs(ctx, hosts) - var msgs []<-chan *Message + var msgs []*Subscription for _, ps := range psubs { - subch, err := ps.Subscribe(ctx, "foobar") + subch, err := ps.Subscribe("foobar") if err != nil { t.Fatal(err) } @@ -125,8 +126,11 @@ func TestBasicFloodsub(t *testing.T) { psubs[owner].Publish("foobar", msg) - for _, resp := range msgs { - got := <-resp + for _, sub := range msgs { + got, err := sub.Next() + if err != nil { + t.Fatal(sub.err) + } if !bytes.Equal(msg, got.Data) { t.Fatal("got wrong message!") } @@ -149,13 +153,13 @@ func TestMultihops(t *testing.T) { connect(t, hosts[3], hosts[4]) connect(t, hosts[4], hosts[5]) - var msgChs []<-chan *Message + var subs []*Subscription for i := 1; i < 6; i++ { - ch, err := psubs[i].Subscribe(ctx, "foobar") + ch, err := psubs[i].Subscribe("foobar") if err != nil { t.Fatal(err) } - msgChs = append(msgChs, ch) + subs = append(subs, ch) } time.Sleep(time.Millisecond * 100) @@ -168,7 +172,7 @@ func TestMultihops(t *testing.T) { // last node in the chain should get the message select { - case out := <-msgChs[4]: + case out := <-subs[4].ch: if !bytes.Equal(out.GetData(), msg) { t.Fatal("got wrong data") } @@ -188,12 +192,12 @@ func TestReconnects(t *testing.T) { connect(t, hosts[0], hosts[1]) connect(t, hosts[0], hosts[2]) - A, err := psubs[1].Subscribe(ctx, "cats") + A, err := psubs[1].Subscribe("cats") if err != nil { t.Fatal(err) } - B, err := psubs[2].Subscribe(ctx, "cats") + B, err := psubs[2].Subscribe("cats") if err != nil { t.Fatal(err) } @@ -209,7 +213,7 @@ func TestReconnects(t *testing.T) { assertReceive(t, A, msg) assertReceive(t, B, msg) - psubs[2].Unsub("cats") + B.Cancel() time.Sleep(time.Millisecond * 50) @@ -221,7 +225,7 @@ func TestReconnects(t *testing.T) { assertReceive(t, A, msg2) select { - case _, ok := <-B: + case _, ok := <-B.ch: if ok { t.Fatal("shouldnt have gotten data on this channel") } @@ -229,12 +233,17 @@ func TestReconnects(t *testing.T) { t.Fatal("timed out waiting for B chan to be closed") } - ch2, err := psubs[2].Subscribe(ctx, "cats") + nSubs := len(psubs[2].myTopics["cats"]) + if nSubs > 0 { + t.Fatal(`B should have 0 subscribers for channel "cats", has`, nSubs) + } + + ch2, err := psubs[2].Subscribe("cats") if err != nil { t.Fatal(err) } - time.Sleep(time.Millisecond * 50) + time.Sleep(time.Millisecond * 100) nextmsg := []byte("ifps is kul") err = psubs[0].Publish("cats", nextmsg) @@ -254,7 +263,7 @@ func TestNoConnection(t *testing.T) { psubs := getPubsubs(ctx, hosts) - ch, err := psubs[5].Subscribe(ctx, "foobar") + ch, err := psubs[5].Subscribe("foobar") if err != nil { t.Fatal(err) } @@ -265,7 +274,7 @@ func TestNoConnection(t *testing.T) { } select { - case <-ch: + case <-ch.ch: t.Fatal("shouldnt have gotten a message") case <-time.After(time.Millisecond * 200): } @@ -288,7 +297,7 @@ func TestSelfReceive(t *testing.T) { time.Sleep(time.Millisecond * 10) - ch, err := psub.Subscribe(ctx, "foobar") + ch, err := psub.Subscribe("foobar") if err != nil { t.Fatal(err) } @@ -311,14 +320,14 @@ func TestOneToOne(t *testing.T) { connect(t, hosts[0], hosts[1]) - ch, err := psubs[1].Subscribe(ctx, "foobar") + ch, err := psubs[1].Subscribe("foobar") if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 50) - checkMessageRouting(t, "foobar", psubs, []<-chan *Message{ch}) + checkMessageRouting(t, "foobar", psubs, []*Subscription{ch}) } func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) { @@ -362,9 +371,9 @@ func TestTreeTopology(t *testing.T) { [8] -> [9] */ - var chs []<-chan *Message + var chs []*Subscription for _, ps := range psubs { - ch, err := ps.Subscribe(ctx, "fizzbuzz") + ch, err := ps.Subscribe("fizzbuzz") if err != nil { t.Fatal(err) } @@ -404,31 +413,31 @@ func TestSubReporting(t *testing.T) { host := getNetHosts(t, ctx, 1)[0] psub := NewFloodSub(ctx, host) - _, err := psub.Subscribe(ctx, "foo") + fooSub, err := psub.Subscribe("foo") if err != nil { t.Fatal(err) } - _, err = psub.Subscribe(ctx, "bar") + barSub, err := psub.Subscribe("bar") if err != nil { t.Fatal(err) } assertHasTopics(t, psub, "foo", "bar") - _, err = psub.Subscribe(ctx, "baz") + _, err = psub.Subscribe("baz") if err != nil { t.Fatal(err) } assertHasTopics(t, psub, "foo", "bar", "baz") - psub.Unsub("bar") + barSub.Cancel() assertHasTopics(t, psub, "foo", "baz") - psub.Unsub("foo") + fooSub.Cancel() assertHasTopics(t, psub, "baz") - _, err = psub.Subscribe(ctx, "fish") + _, err = psub.Subscribe("fish") if err != nil { t.Fatal(err) } diff --git a/subscription.go b/subscription.go new file mode 100644 index 0000000..6fd01fb --- /dev/null +++ b/subscription.go @@ -0,0 +1,26 @@ +package floodsub + +type Subscription struct { + topic string + ch chan *Message + cancelCh chan<- *Subscription + err error +} + +func (sub *Subscription) Topic() string { + return sub.topic +} + +func (sub *Subscription) Next() (*Message, error) { + msg, ok := <-sub.ch + + if !ok { + return msg, sub.err + } + + return msg, nil +} + +func (sub *Subscription) Cancel() { + sub.cancelCh <- sub +}