From 28f2c2f0946b5bba196c45952d25566390cbb5f1 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Wed, 14 Sep 2016 15:11:41 -0700 Subject: [PATCH] add way to query subscribed topics --- floodsub.go | 22 ++++++++++++++++++- floodsub_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/floodsub.go b/floodsub.go index 3e3b7bd..71b6f61 100644 --- a/floodsub.go +++ b/floodsub.go @@ -33,6 +33,9 @@ type PubSub struct { // addSub is a control channel for us to add and remove subscriptions addSub chan *addSub + // + getTopics chan *topicReq + // a notification channel for incoming streams from other peers newPeers chan inet.Stream @@ -75,6 +78,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { newPeers: make(chan inet.Stream), peerDead: make(chan peer.ID), addSub: make(chan *addSub), + getTopics: make(chan *topicReq), myTopics: make(map[string]chan *Message), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), @@ -112,6 +116,12 @@ func (p *PubSub) processLoop(ctx context.Context) { for _, t := range p.topics { delete(t, pid) } + case treq := <-p.getTopics: + var out []string + for t := range p.myTopics { + out = append(out, t) + } + treq.resp <- out case sub := <-p.addSub: p.handleSubscriptionChange(sub) case rpc := <-p.incoming: @@ -270,12 +280,22 @@ type addSub struct { resp chan chan *Message } -func (p *PubSub) Subscribe(topic string) (<-chan *Message, error) { +func (p *PubSub) Subscribe(ctx context.Context, topic string) (<-chan *Message, error) { return p.SubscribeComplicated(&pb.TopicDescriptor{ Name: proto.String(topic), }) } +type topicReq struct { + resp chan []string +} + +func (p *PubSub) GetTopics() []string { + out := make(chan []string, 1) + p.getTopics <- &topicReq{resp: out} + return <-out +} + func (p *PubSub) SubscribeComplicated(td *pb.TopicDescriptor) (<-chan *Message, error) { if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { return nil, fmt.Errorf("Auth method not yet supported") diff --git a/floodsub_test.go b/floodsub_test.go index 86f1b6a..9155b47 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "math/rand" + "sort" "testing" "time" @@ -360,3 +361,58 @@ func TestTreeTopology(t *testing.T) { checkMessageRouting(t, "fizzbuzz", []*PubSub{psubs[9], psubs[3]}, chs) } + +func assertHasTopics(t *testing.T, ps *PubSub, exptopics ...string) { + topics := ps.GetTopics() + sort.Strings(topics) + sort.Strings(exptopics) + + if len(topics) != len(exptopics) { + t.Fatalf("expected to have %v, but got %v", exptopics, topics) + } + + for i, v := range exptopics { + if topics[i] != v { + t.Fatalf("expected %s but have %s", v, topics[i]) + } + } +} + +func TestSubReporting(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + host := getNetHosts(t, ctx, 1)[0] + psub := NewFloodSub(ctx, host) + + _, err := psub.Subscribe(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + _, err = psub.Subscribe(ctx, "bar") + if err != nil { + t.Fatal(err) + } + + assertHasTopics(t, psub, "foo", "bar") + + _, err = psub.Subscribe(ctx, "baz") + if err != nil { + t.Fatal(err) + } + + assertHasTopics(t, psub, "foo", "bar", "baz") + + psub.Unsub("bar") + assertHasTopics(t, psub, "foo", "baz") + psub.Unsub("foo") + assertHasTopics(t, psub, "baz") + + _, err = psub.Subscribe(ctx, "fish") + if err != nil { + t.Fatal(err) + } + + assertHasTopics(t, psub, "baz", "fish") +}