From 686c928d4e2f72182facc294683cc2bd92061b8d Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Sat, 2 Nov 2019 21:12:21 -0400 Subject: [PATCH] pubsub and topic methods now return error if the pubsub context has been cancelled instead of hanging --- pubsub.go | 35 +++++++++++++++++++++++++++++------ topic.go | 27 +++++++++++++++++++++++---- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/pubsub.go b/pubsub.go index cffb83b..a769a68 100644 --- a/pubsub.go +++ b/pubsub.go @@ -789,9 +789,13 @@ func (p *PubSub) tryJoin(topic string, opts ...TopicOpt) (*Topic, bool, error) { } resp := make(chan *Topic, 1) - t.p.addTopic <- &addTopicReq{ + select { + case t.p.addTopic <- &addTopicReq{ topic: t, resp: resp, + }: + case <-t.p.ctx.Done(): + return nil, false, t.p.ctx.Err() } returnedTopic := <-resp @@ -848,7 +852,11 @@ type topicReq struct { // GetTopics returns the topics this node is subscribed to. func (p *PubSub) GetTopics() []string { out := make(chan []string, 1) - p.getTopics <- &topicReq{resp: out} + select { + case p.getTopics <- &topicReq{resp: out}: + case <-p.ctx.Done(): + return nil + } return <-out } @@ -880,16 +888,23 @@ type listPeerReq struct { // ListPeers returns a list of peers we are connected to in the given topic. func (p *PubSub) ListPeers(topic string) []peer.ID { out := make(chan []peer.ID) - p.getPeers <- &listPeerReq{ + select { + case p.getPeers <- &listPeerReq{ resp: out, topic: topic, + }: + case <-p.ctx.Done(): + return nil } return <-out } // BlacklistPeer blacklists a peer; all messages from this peer will be unconditionally dropped. func (p *PubSub) BlacklistPeer(pid peer.ID) { - p.blacklistPeer <- pid + select { + case p.blacklistPeer <- pid: + case <-p.ctx.Done(): + } } // RegisterTopicValidator registers a validator for topic. @@ -910,7 +925,11 @@ func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...Val } } - p.addVal <- addVal + select { + case p.addVal <- addVal: + case <-p.ctx.Done(): + return p.ctx.Err() + } return <-addVal.resp } @@ -922,6 +941,10 @@ func (p *PubSub) UnregisterTopicValidator(topic string) error { resp: make(chan error, 1), } - p.rmVal <- rmVal + select { + case p.rmVal <- rmVal: + case <-p.ctx.Done(): + return p.ctx.Err() + } return <-rmVal.resp } diff --git a/topic.go b/topic.go index 651ae11..5434791 100644 --- a/topic.go +++ b/topic.go @@ -51,7 +51,9 @@ func (t *Topic) EventHandler(opts ...TopicEventHandlerOpt) (*TopicEventHandler, } done := make(chan struct{}, 1) - t.p.eval <- func() { + + select { + case t.p.eval <- func() { tmap := t.p.topics[t.topic] for p := range tmap { h.evtLog[p] = PeerJoin @@ -61,6 +63,9 @@ func (t *Topic) EventHandler(opts ...TopicEventHandlerOpt) (*TopicEventHandler, t.evtHandlers[h] = struct{}{} t.evtHandlerMux.Unlock() done <- struct{}{} + }: + case <-t.p.ctx.Done(): + return nil, t.p.ctx.Err() } <-done @@ -104,9 +109,13 @@ func (t *Topic) Subscribe(opts ...SubOpt) (*Subscription, error) { t.p.disc.Discover(sub.topic) - t.p.addSub <- &addSubReq{ + select { + case t.p.addSub <- &addSubReq{ sub: sub, resp: out, + }: + case <-t.p.ctx.Done(): + return nil, t.p.ctx.Err() } return <-out, nil @@ -157,7 +166,11 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error t.p.disc.Bootstrap(ctx, t.topic, pub.ready) } - t.p.publish <- &Message{m, id} + select { + case t.p.publish <- &Message{m, id}: + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + } return nil } @@ -181,7 +194,13 @@ func (t *Topic) Close() error { } req := &rmTopicReq{t, make(chan error, 1)} - t.p.rmTopic <- req + + select { + case t.p.rmTopic <- req: + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + } + err := <-req.resp if err == nil {