diff --git a/floodsub_test.go b/floodsub_test.go index 4e227d5..c15c57d 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -1063,3 +1063,45 @@ func TestImproperlySignedMessageRejected(t *testing.T) { ) } } + +func TestSubscriptionNotification(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const numHosts = 20 + hosts := getNetHosts(t, ctx, numHosts) + + psubs := getPubsubs(ctx, hosts) + + msgs := make([]*Subscription, numHosts) + subPeersFound := make([]map[peer.ID]struct{}, numHosts) + for i, ps := range psubs { + subch, err := ps.Subscribe("foobar") + if err != nil { + t.Fatal(err) + } + + msgs[i] = subch + peersFound := make(map[peer.ID]struct{}) + subPeersFound[i] = peersFound + go func(peersFound map[peer.ID]struct{}) { + for i := 0; i < numHosts-1; i++ { + pid, err := subch.NextSubscribedPeer(ctx) + if err != nil { + t.Fatal(err) + } + peersFound[pid] = struct{}{} + } + }(peersFound) + } + + connectAll(t, hosts) + + time.Sleep(time.Millisecond * 100) + + for _, peersFound := range subPeersFound { + if len(peersFound) != numHosts-1 { + t.Fatal("incorrect number of peers found") + } + } +} diff --git a/pubsub.go b/pubsub.go index 6df169d..2be2da7 100644 --- a/pubsub.go +++ b/pubsub.go @@ -78,6 +78,9 @@ type PubSub struct { // topics tracks which topics each of our peers are subscribed to topics map[string]map[peer.ID]struct{} + // a set of notification channels for newly subscribed peers + newSubs map[string]chan peer.ID + // sendMsg handles messages that have been validated sendMsg chan *sendReq @@ -418,6 +421,7 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) { sub.err = fmt.Errorf("subscription cancelled by calling sub.Cancel()") close(sub.ch) + close(sub.inboundSubs) delete(subs, sub) if len(subs) == 0 { @@ -447,6 +451,7 @@ func (p *PubSub) handleAddSubscription(req *addSubReq) { subs = p.myTopics[sub.topic] } + sub.inboundSubs = make(chan peer.ID, 32) sub.ch = make(chan *Message, 32) sub.cancelCh = p.cancelCh @@ -570,7 +575,19 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { p.topics[t] = tmap } - tmap[rpc.from] = struct{}{} + if _, ok = tmap[rpc.from]; !ok { + tmap[rpc.from] = struct{}{} + if subs, ok := p.myTopics[t]; ok { + inboundPeer := rpc.from + for s := range subs { + select { + case s.inboundSubs <- inboundPeer: + default: + log.Infof("Can't deliver event to subscription for topic %s; subscriber too slow", t) + } + } + } + } } else { tmap, ok := p.topics[t] if !ok { diff --git a/subscription.go b/subscription.go index 66a9e51..ad70778 100644 --- a/subscription.go +++ b/subscription.go @@ -2,13 +2,15 @@ package pubsub import ( "context" + "github.com/libp2p/go-libp2p-core/peer" ) type Subscription struct { - topic string - ch chan *Message - cancelCh chan<- *Subscription - err error + topic string + ch chan *Message + cancelCh chan<- *Subscription + inboundSubs chan peer.ID + err error } func (sub *Subscription) Topic() string { @@ -31,3 +33,16 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) { func (sub *Subscription) Cancel() { sub.cancelCh <- sub } + +func (sub *Subscription) NextSubscribedPeer(ctx context.Context) (peer.ID, error) { + select { + case newPeer, ok := <-sub.inboundSubs: + if !ok { + return newPeer, sub.err + } + + return newPeer, nil + case <-ctx.Done(): + return "", ctx.Err() + } +}