diff --git a/floodsub_test.go b/floodsub_test.go index 47df0f5..9cb8dfc 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -1064,7 +1064,72 @@ func TestImproperlySignedMessageRejected(t *testing.T) { } } -func TestSubscriptionNotification(t *testing.T) { +func TestSubscriptionJoinNotification(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const numLateSubscribers = 10 + const numHosts = 20 + hosts := getNetHosts(t, ctx, numHosts) + + psubs := getPubsubs(ctx, hosts) + + msgs := make([]*Subscription, numHosts) + subPeersFound := make([]map[peer.ID]struct{}, numHosts) + + // Have some peers subscribe earlier than other peers. + // This exercises whether we get subscription notifications from + // existing peers. + for i, ps := range psubs[numLateSubscribers:] { + subch, err := ps.Subscribe("foobar") + if err != nil { + t.Fatal(err) + } + + msgs[i] = subch + } + + connectAll(t, hosts) + + time.Sleep(time.Millisecond * 100) + + // Have the rest subscribe + for i, ps := range psubs[:numLateSubscribers] { + subch, err := ps.Subscribe("foobar") + if err != nil { + t.Fatal(err) + } + + msgs[i+numLateSubscribers] = subch + } + + wg := sync.WaitGroup{} + for i := 0; i < numHosts; i++ { + peersFound := make(map[peer.ID]struct{}) + subPeersFound[i] = peersFound + sub := msgs[i] + wg.Add(1) + go func(peersFound map[peer.ID]struct{}) { + defer wg.Done() + for i := 0; i < numHosts-1; i++ { + pid, err := sub.NextPeerJoin(ctx) + if err != nil { + t.Fatal(err) + } + peersFound[pid] = struct{}{} + } + }(peersFound) + } + + wg.Wait() + for _, peersFound := range subPeersFound { + if len(peersFound) != numHosts-1 { + t.Fatal("incorrect number of peers found") + } + } +} + +func TestSubscriptionLeaveNotification(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1076,7 +1141,7 @@ func TestSubscriptionNotification(t *testing.T) { msgs := make([]*Subscription, numHosts) subPeersFound := make([]map[peer.ID]struct{}, numHosts) - wg := sync.WaitGroup{} + // Subscribe all peers and wait until they've all been found for i, ps := range psubs { subch, err := ps.Subscribe("foobar") if err != nil { @@ -1084,13 +1149,22 @@ func TestSubscriptionNotification(t *testing.T) { } msgs[i] = subch + } + + connectAll(t, hosts) + + time.Sleep(time.Millisecond * 100) + + wg := sync.WaitGroup{} + for i := 0; i < numHosts; i++ { peersFound := make(map[peer.ID]struct{}) subPeersFound[i] = peersFound + sub := msgs[i] wg.Add(1) go func(peersFound map[peer.ID]struct{}) { defer wg.Done() for i := 0; i < numHosts-1; i++ { - pid, err := subch.NextSubscribedPeer(ctx) + pid, err := sub.NextPeerJoin(ctx) if err != nil { t.Fatal(err) } @@ -1099,14 +1173,34 @@ func TestSubscriptionNotification(t *testing.T) { }(peersFound) } - connectAll(t, hosts) - - time.Sleep(time.Millisecond * 100) - wg.Wait() for _, peersFound := range subPeersFound { if len(peersFound) != numHosts-1 { t.Fatal("incorrect number of peers found") } } + + // Test removing peers and verifying that they cause events + msgs[1].Cancel() + hosts[2].Close() + psubs[0].BlacklistPeer(hosts[3].ID()) + + leavingPeers := make(map[peer.ID]struct{}) + for i := 0; i < 3; i++ { + pid, err := msgs[0].NextPeerLeave(ctx) + if err != nil { + t.Fatal(err) + } + leavingPeers[pid] = struct{}{} + } + + if _, ok := leavingPeers[hosts[1].ID()]; !ok { + t.Fatal(fmt.Errorf("canceling subscription did not cause a leave event")) + } + if _, ok := leavingPeers[hosts[2].ID()]; !ok { + t.Fatal(fmt.Errorf("closing host did not cause a leave event")) + } + if _, ok := leavingPeers[hosts[3].ID()]; !ok { + t.Fatal(fmt.Errorf("blacklisting peer did not cause a leave event")) + } } diff --git a/pubsub.go b/pubsub.go index 2be2da7..dd82c43 100644 --- a/pubsub.go +++ b/pubsub.go @@ -336,8 +336,9 @@ func (p *PubSub) processLoop(ctx context.Context) { } delete(p.peers, pid) - for _, t := range p.topics { - delete(t, pid) + for t, tmap := range p.topics { + delete(tmap, pid) + p.notifySubscriberLeft(t, pid) } p.rt.RemovePeer(pid) @@ -395,8 +396,9 @@ func (p *PubSub) processLoop(ctx context.Context) { if ok { close(ch) delete(p.peers, pid) - for _, t := range p.topics { - delete(t, pid) + for t, tmap := range p.topics { + delete(tmap, pid) + p.notifySubscriberLeft(t, pid) } p.rt.RemovePeer(pid) } @@ -422,6 +424,7 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) { sub.err = fmt.Errorf("subscription cancelled by calling sub.Cancel()") close(sub.ch) close(sub.inboundSubs) + close(sub.leavingSubs) delete(subs, sub) if len(subs) == 0 { @@ -451,10 +454,21 @@ func (p *PubSub) handleAddSubscription(req *addSubReq) { subs = p.myTopics[sub.topic] } - sub.inboundSubs = make(chan peer.ID, 32) + tmap := p.topics[sub.topic] + inboundBufSize := len(tmap) + if inboundBufSize < 32 { + inboundBufSize = 32 + } + sub.ch = make(chan *Message, 32) + sub.inboundSubs = make(chan peer.ID, inboundBufSize) + sub.leavingSubs = make(chan peer.ID, 32) sub.cancelCh = p.cancelCh + for pid := range tmap { + sub.inboundSubs <- pid + } + p.myTopics[sub.topic][sub] = struct{}{} req.resp <- sub @@ -565,6 +579,18 @@ func (p *PubSub) subscribedToMsg(msg *pb.Message) bool { return false } +func (p *PubSub) notifySubscriberLeft(topic string, pid peer.ID) { + if subs, ok := p.myTopics[topic]; ok { + for s := range subs { + select { + case s.leavingSubs <- pid: + default: + log.Infof("Can't deliver leave event to subscription for topic %s; subscriber too slow", topic) + } + } + } +} + func (p *PubSub) handleIncomingRPC(rpc *RPC) { for _, subopt := range rpc.GetSubscriptions() { t := subopt.GetTopicid() @@ -583,7 +609,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { select { case s.inboundSubs <- inboundPeer: default: - log.Infof("Can't deliver event to subscription for topic %s; subscriber too slow", t) + log.Infof("Can't deliver join event to subscription for topic %s; subscriber too slow", t) } } } @@ -594,6 +620,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { continue } delete(tmap, rpc.from) + p.notifySubscriberLeft(t, rpc.from) } } diff --git a/subscription.go b/subscription.go index ad70778..61f6e41 100644 --- a/subscription.go +++ b/subscription.go @@ -10,6 +10,7 @@ type Subscription struct { ch chan *Message cancelCh chan<- *Subscription inboundSubs chan peer.ID + leavingSubs chan peer.ID err error } @@ -34,7 +35,7 @@ func (sub *Subscription) Cancel() { sub.cancelCh <- sub } -func (sub *Subscription) NextSubscribedPeer(ctx context.Context) (peer.ID, error) { +func (sub *Subscription) NextPeerJoin(ctx context.Context) (peer.ID, error) { select { case newPeer, ok := <-sub.inboundSubs: if !ok { @@ -46,3 +47,16 @@ func (sub *Subscription) NextSubscribedPeer(ctx context.Context) (peer.ID, error return "", ctx.Err() } } + +func (sub *Subscription) NextPeerLeave(ctx context.Context) (peer.ID, error) { + select { + case leavingPeer, ok := <-sub.leavingSubs: + if !ok { + return leavingPeer, sub.err + } + + return leavingPeer, nil + case <-ctx.Done(): + return "", ctx.Err() + } +}