diff --git a/floodsub.go b/floodsub.go index 4a3e8cc..c52d0b3 100644 --- a/floodsub.go +++ b/floodsub.go @@ -39,6 +39,9 @@ type PubSub struct { // getPeers chan *listPeerReq + // + addFeedHook chan *addFeedReq + // a notification channel for incoming streams from other peers newPeers chan inet.Stream @@ -46,7 +49,7 @@ type PubSub struct { peerDead chan peer.ID // The set of topics we are subscribed to - myTopics map[string]chan *Message + myTopics map[string][]*clientFeed // topics tracks which topics each of our peers are subscribed to topics map[string]map[peer.ID]struct{} @@ -83,7 +86,8 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { getPeers: make(chan *listPeerReq), addSub: make(chan *addSub), getTopics: make(chan *topicReq), - myTopics: make(map[string]chan *Message), + addFeedHook: make(chan *addFeedReq, 32), + myTopics: make(map[string][]*clientFeed), topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), seenMessages: timecache.NewTimeCache(time.Second * 30), @@ -114,6 +118,21 @@ func (p *PubSub) processLoop(ctx context.Context) { p.peers[pid] = messages + case req := <-p.addFeedHook: + feeds, ok := p.myTopics[req.topic] + + var out chan *Message + if ok { + out = make(chan *Message, 32) + nfeed := &clientFeed{ + out: out, + ctx: req.ctx, + } + + p.myTopics[req.topic] = append(feeds, nfeed) + } + + req.resp <- out case pid := <-p.peerDead: ch, ok := p.peers[pid] if ok { @@ -170,23 +189,21 @@ func (p *PubSub) handleSubscriptionChange(sub *addSub) { Subscribe: &sub.sub, } - ch, ok := p.myTopics[sub.topic] + feeds, ok := p.myTopics[sub.topic] if sub.sub { if ok { - // we don't allow multiple subs per topic at this point - sub.resp <- nil return } - resp := make(chan *Message, 16) - p.myTopics[sub.topic] = resp - sub.resp <- resp + p.myTopics[sub.topic] = nil } else { if !ok { return } - close(ch) + for _, f := range feeds { + close(f.out) + } delete(p.myTopics, sub.topic) } @@ -198,9 +215,26 @@ func (p *PubSub) handleSubscriptionChange(sub *addSub) { func (p *PubSub) notifySubs(msg *pb.Message) { for _, topic := range msg.GetTopicIDs() { - subch, ok := p.myTopics[topic] - if ok { - subch <- &Message{msg} + var cleanup bool + feeds := p.myTopics[topic] + for _, f := range feeds { + select { + case f.out <- &Message{msg}: + case <-f.ctx.Done(): + close(f.out) + f.out = nil + cleanup = true + } + } + + if cleanup { + out := make([]*clientFeed, 0, len(feeds)) + for _, f := range feeds { + if f.out != nil { + out = append(out, f) + } + } + p.myTopics[topic] = out } } } @@ -310,9 +344,15 @@ type addSub struct { } func (p *PubSub) Subscribe(ctx context.Context, topic string) (<-chan *Message, error) { - return p.SubscribeComplicated(&pb.TopicDescriptor{ + err := p.AddTopicSubscription(&pb.TopicDescriptor{ Name: proto.String(topic), }) + + if err != nil { + return nil, err + } + + return p.GetFeed(ctx, topic) } type topicReq struct { @@ -325,28 +365,47 @@ func (p *PubSub) GetTopics() []string { return <-out } -func (p *PubSub) SubscribeComplicated(td *pb.TopicDescriptor) (<-chan *Message, error) { +func (p *PubSub) AddTopicSubscription(td *pb.TopicDescriptor) error { if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE { - return nil, fmt.Errorf("Auth method not yet supported") + return fmt.Errorf("Auth method not yet supported") } if td.GetEnc().GetMode() != pb.TopicDescriptor_EncOpts_NONE { - return nil, fmt.Errorf("Encryption method not yet supported") + return fmt.Errorf("Encryption method not yet supported") } - resp := make(chan chan *Message) p.addSub <- &addSub{ topic: td.GetName(), - resp: resp, sub: true, } - outch := <-resp - if outch == nil { - return nil, fmt.Errorf("error, duplicate subscription") + return nil +} + +type addFeedReq struct { + ctx context.Context + topic string + resp chan chan *Message +} + +type clientFeed struct { + out chan *Message + ctx context.Context +} + +func (p *PubSub) GetFeed(ctx context.Context, topic string) (<-chan *Message, error) { + out := make(chan chan *Message, 1) + p.addFeedHook <- &addFeedReq{ + ctx: ctx, + topic: topic, + resp: out, } - return outch, nil + resp := <-out + if resp == nil { + return nil, fmt.Errorf("not subscribed to topic %s", topic) + } + return resp, nil } func (p *PubSub) Unsub(topic string) {