diff --git a/waku/v2/broadcast.go b/waku/v2/broadcast.go index a737252e..0c8c12f3 100644 --- a/waku/v2/broadcast.go +++ b/waku/v2/broadcast.go @@ -1,6 +1,9 @@ package v2 import ( + "context" + "errors" + "github.com/waku-org/go-waku/waku/v2/protocol" ) @@ -18,6 +21,9 @@ type chOperation struct { type broadcastOutputs map[chan<- *protocol.Envelope]struct{} type broadcaster struct { + bufLen int + cancel context.CancelFunc + input chan *protocol.Envelope reg chan chOperation unreg chan chOperation @@ -37,8 +43,10 @@ type Broadcaster interface { Unregister(topic *string, newch chan<- *protocol.Envelope) // Unregister a subscriptor channel and return a channel to wait until this operation is done WaitUnregister(topic *string, newch chan<- *protocol.Envelope) doneCh + // Start + Start(ctx context.Context) error // Shut this broadcaster down. - Close() + Stop() // Submit a new object to all subscribers Submit(*protocol.Envelope) } @@ -58,11 +66,15 @@ func (b *broadcaster) broadcast(m *protocol.Envelope) { } } -func (b *broadcaster) run() { +func (b *broadcaster) run(ctx context.Context) { for { select { - case m := <-b.input: - b.broadcast(m) + case <-ctx.Done(): + return + case m, ok := <-b.input: + if ok { + b.broadcast(m) + } case broadcastee, ok := <-b.reg: if ok { if broadcastee.topic != nil { @@ -109,17 +121,43 @@ func (b *broadcaster) run() { // It's used to register subscriptors that will need to receive // an Envelope containing a WakuMessage func NewBroadcaster(buflen int) Broadcaster { - b := &broadcaster{ - input: make(chan *protocol.Envelope, buflen), - reg: make(chan chOperation), - unreg: make(chan chOperation), - outputs: make(broadcastOutputs), - outputsPerTopic: make(map[string]broadcastOutputs), + return &broadcaster{ + bufLen: buflen, + } +} + +func (b *broadcaster) Start(ctx context.Context) error { + if b.cancel != nil { + return errors.New("already started") } - go b.run() + ctx, cancel := context.WithCancel(ctx) - return b + b.cancel = cancel + b.input = make(chan *protocol.Envelope, b.bufLen) + b.reg = make(chan chOperation) + b.unreg = make(chan chOperation) + b.outputs = make(broadcastOutputs) + b.outputsPerTopic = make(map[string]broadcastOutputs) + + go b.run(ctx) + + return nil +} + +func (b *broadcaster) Stop() { + if b.cancel != nil { + return + } + + b.cancel() + + close(b.input) + close(b.reg) + close(b.unreg) + b.outputs = nil + b.outputsPerTopic = nil + b.cancel = nil } // Register a subscriptor channel and return a channel to wait until this operation is done diff --git a/waku/v2/broadcast_test.go b/waku/v2/broadcast_test.go index 24027d86..367a6624 100644 --- a/waku/v2/broadcast_test.go +++ b/waku/v2/broadcast_test.go @@ -1,9 +1,11 @@ package v2 import ( + "context" "sync" "testing" + "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/waku/v2/protocol" "github.com/waku-org/go-waku/waku/v2/protocol/pb" "github.com/waku-org/go-waku/waku/v2/utils" @@ -15,7 +17,8 @@ func TestBroadcast(t *testing.T) { wg := sync.WaitGroup{} b := NewBroadcaster(100) - defer b.Close() + require.NoError(t, b.Start(context.Background())) + defer b.Stop() for i := 0; i < 5; i++ { wg.Add(1) @@ -40,7 +43,8 @@ func TestBroadcastWait(t *testing.T) { wg := sync.WaitGroup{} b := NewBroadcaster(100) - defer b.Close() + require.NoError(t, b.Start(context.Background())) + defer b.Stop() for i := 0; i < 5; i++ { wg.Add(1) @@ -65,7 +69,8 @@ func TestBroadcastWait(t *testing.T) { func TestBroadcastCleanup(t *testing.T) { b := NewBroadcaster(100) + require.NoError(t, b.Start(context.Background())) topic := "test" b.Register(&topic, make(chan *protocol.Envelope)) - b.Close() + b.Stop() } diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index 83b917ad..44af2116 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -306,12 +306,17 @@ func (w *WakuNode) Start(ctx context.Context) error { go w.watchMultiaddressChanges(ctx) go w.watchENRChanges(ctx) + err := w.bcaster.Start(ctx) + if err != nil { + return err + } + if w.opts.keepAliveInterval > time.Duration(0) { w.wg.Add(1) go w.startKeepAlive(ctx, w.opts.keepAliveInterval) } - err := w.peerConnector.Start(ctx) + err = w.peerConnector.Start(ctx) if err != nil { return err } @@ -420,7 +425,7 @@ func (w *WakuNode) Stop() { w.cancel() - w.bcaster.Close() + w.bcaster.Stop() defer w.connectionNotif.Close() defer w.protocolEventSub.Close() @@ -453,6 +458,8 @@ func (w *WakuNode) Stop() { w.wg.Wait() close(w.enrChangeCh) + + w.cancel = nil } // Host returns the libp2p Host used by the WakuNode diff --git a/waku/v2/protocol/filter/filter_test.go b/waku/v2/protocol/filter/filter_test.go index ac2960d1..adc95c0b 100644 --- a/waku/v2/protocol/filter/filter_test.go +++ b/waku/v2/protocol/filter/filter_test.go @@ -42,7 +42,9 @@ func makeWakuFilterLightNode(t *testing.T) (*WakuFilterLightnode, host.Host) { host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - filterPush := NewWakuFilterLightnode(host, v2.NewBroadcaster(10), timesource.NewDefaultClock(), utils.Logger()) + b := v2.NewBroadcaster(10) + require.NoError(t, b.Start(context.Background())) + filterPush := NewWakuFilterLightnode(host, b, timesource.NewDefaultClock(), utils.Logger()) err = filterPush.Start(context.Background()) require.NoError(t, err) @@ -70,6 +72,7 @@ func TestWakuFilter(t *testing.T) { defer node1.Stop() broadcaster := v2.NewBroadcaster(10) + require.NoError(t, broadcaster.Start(context.Background())) node2, sub2, host2 := makeWakuRelay(t, testTopic, broadcaster) defer node2.Stop() defer sub2.Unsubscribe() @@ -158,6 +161,7 @@ func TestSubscriptionPing(t *testing.T) { defer node1.Stop() broadcaster := v2.NewBroadcaster(10) + require.NoError(t, broadcaster.Start(context.Background())) node2, sub2, host2 := makeWakuRelay(t, testTopic, broadcaster) defer node2.Stop() defer sub2.Unsubscribe() @@ -197,11 +201,14 @@ func TestWakuFilterPeerFailure(t *testing.T) { node1, host1 := makeWakuFilterLightNode(t) broadcaster := v2.NewBroadcaster(10) + require.NoError(t, broadcaster.Start(context.Background())) node2, sub2, host2 := makeWakuRelay(t, testTopic, broadcaster) defer node2.Stop() defer sub2.Unsubscribe() - node2Filter := NewWakuFilterFullnode(host2, v2.NewBroadcaster(10), timesource.NewDefaultClock(), utils.Logger(), WithTimeout(5*time.Second)) + broadcaster2 := v2.NewBroadcaster(10) + require.NoError(t, broadcaster2.Start(context.Background())) + node2Filter := NewWakuFilterFullnode(host2, broadcaster2, timesource.NewDefaultClock(), utils.Logger(), WithTimeout(5*time.Second)) err := node2Filter.Start(ctx) require.NoError(t, err) diff --git a/waku/v2/protocol/legacy_filter/filter_map_test.go b/waku/v2/protocol/legacy_filter/filter_map_test.go index afc1066d..d369ab8a 100644 --- a/waku/v2/protocol/legacy_filter/filter_map_test.go +++ b/waku/v2/protocol/legacy_filter/filter_map_test.go @@ -1,6 +1,7 @@ package legacy_filter import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -10,7 +11,9 @@ import ( ) func TestFilterMap(t *testing.T) { - fmap := NewFilterMap(v2.NewBroadcaster(100), timesource.NewDefaultClock()) + b := v2.NewBroadcaster(100) + require.NoError(t, b.Start(context.Background())) + fmap := NewFilterMap(b, timesource.NewDefaultClock()) filter := Filter{ PeerID: "id", diff --git a/waku/v2/protocol/legacy_filter/waku_filter_test.go b/waku/v2/protocol/legacy_filter/waku_filter_test.go index e9baa977..dc1cd9df 100644 --- a/waku/v2/protocol/legacy_filter/waku_filter_test.go +++ b/waku/v2/protocol/legacy_filter/waku_filter_test.go @@ -41,7 +41,9 @@ func makeWakuFilter(t *testing.T) (*WakuFilter, host.Host) { host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - filter := NewWakuFilter(host, v2.NewBroadcaster(10), false, timesource.NewDefaultClock(), utils.Logger()) + b := v2.NewBroadcaster(10) + require.NoError(t, b.Start(context.Background())) + filter := NewWakuFilter(host, b, false, timesource.NewDefaultClock(), utils.Logger()) err = filter.Start(context.Background()) require.NoError(t, err) @@ -69,6 +71,7 @@ func TestWakuFilter(t *testing.T) { defer node1.Stop() broadcaster := v2.NewBroadcaster(10) + require.NoError(t, broadcaster.Start(context.Background())) node2, sub2, host2 := makeWakuRelay(t, testTopic, broadcaster) defer node2.Stop() defer sub2.Unsubscribe() @@ -157,11 +160,14 @@ func TestWakuFilterPeerFailure(t *testing.T) { node1, host1 := makeWakuFilter(t) broadcaster := v2.NewBroadcaster(10) + require.NoError(t, broadcaster.Start(context.Background())) node2, sub2, host2 := makeWakuRelay(t, testTopic, broadcaster) defer node2.Stop() defer sub2.Unsubscribe() - node2Filter := NewWakuFilter(host2, v2.NewBroadcaster(10), true, timesource.NewDefaultClock(), utils.Logger(), WithTimeout(3*time.Second)) + broadcaster2 := v2.NewBroadcaster(10) + require.NoError(t, broadcaster2.Start(context.Background())) + node2Filter := NewWakuFilter(host2, broadcaster2, true, timesource.NewDefaultClock(), utils.Logger(), WithTimeout(3*time.Second)) err := node2Filter.Start(ctx) require.NoError(t, err) diff --git a/waku/v2/protocol/lightpush/waku_lightpush_test.go b/waku/v2/protocol/lightpush/waku_lightpush_test.go index dc6da6bf..ce7adece 100644 --- a/waku/v2/protocol/lightpush/waku_lightpush_test.go +++ b/waku/v2/protocol/lightpush/waku_lightpush_test.go @@ -26,7 +26,9 @@ func makeWakuRelay(t *testing.T, topic string) (*relay.WakuRelay, *relay.Subscri host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - relay := relay.NewWakuRelay(host, v2.NewBroadcaster(10), 0, timesource.NewDefaultClock(), utils.Logger()) + b := v2.NewBroadcaster(10) + require.NoError(t, b.Start(context.Background())) + relay := relay.NewWakuRelay(host, b, 0, timesource.NewDefaultClock(), utils.Logger()) require.NoError(t, err) err = relay.Start(context.Background()) require.NoError(t, err) diff --git a/waku/v2/protocol/noise/pairing_relay_messenger.go b/waku/v2/protocol/noise/pairing_relay_messenger.go index 1440653f..281ce445 100644 --- a/waku/v2/protocol/noise/pairing_relay_messenger.go +++ b/waku/v2/protocol/noise/pairing_relay_messenger.go @@ -58,12 +58,17 @@ func NewWakuRelayMessenger(ctx context.Context, r *relay.WakuRelay, pubsubTopic subscriptionChPerContentTopic: make(map[string][]contentTopicSubscription), } + err = wr.broadcaster.Start(ctx) + if err != nil { + return nil, err + } + go func() { for { select { case <-ctx.Done(): subs.Unsubscribe() - wr.broadcaster.Close() + wr.broadcaster.Stop() return case envelope := <-subs.C: if envelope != nil { diff --git a/waku/v2/protocol/noise/pairing_test.go b/waku/v2/protocol/noise/pairing_test.go index f99569e4..406f2687 100644 --- a/waku/v2/protocol/noise/pairing_test.go +++ b/waku/v2/protocol/noise/pairing_test.go @@ -26,7 +26,9 @@ func createRelayNode(t *testing.T) (host.Host, *relay.WakuRelay) { host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - relay := relay.NewWakuRelay(host, v2.NewBroadcaster(1024), 0, timesource.NewDefaultClock(), utils.Logger()) + b := v2.NewBroadcaster(1024) + require.NoError(t, b.Start(context.Background())) + relay := relay.NewWakuRelay(host, b, 0, timesource.NewDefaultClock(), utils.Logger()) err = relay.Start(context.Background()) require.NoError(t, err) diff --git a/waku/v2/protocol/relay/validators.go b/waku/v2/protocol/relay/validators.go index 30b1a826..97ceb444 100644 --- a/waku/v2/protocol/relay/validators.go +++ b/waku/v2/protocol/relay/validators.go @@ -21,9 +21,9 @@ func MsgHash(pubSubTopic string, msg *pb.WakuMessage) []byte { return hash.SHA256([]byte(pubSubTopic), msg.Payload, []byte(msg.ContentTopic)) } -func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.PublicKey) { +func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.PublicKey) error { w.log.Info("adding validator to signed topic", zap.String("topic", topic), zap.String("publicKey", hex.EncodeToString(elliptic.Marshal(publicKey.Curve, publicKey.X, publicKey.Y)))) - w.pubsub.RegisterTopicValidator(topic, func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { + err := w.pubsub.RegisterTopicValidator(topic, func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { msg := new(pb.WakuMessage) err := proto.Unmarshal(message.Data, msg) if err != nil { @@ -35,6 +35,7 @@ func (w *WakuRelay) AddSignedTopicValidator(topic string, publicKey *ecdsa.Publi return ecdsa.VerifyASN1(publicKey, msgHash, signature) }) + return err } func (w *WakuRelay) SignMessage(privKey *ecdsa.PrivateKey, topic string, msg *pb.WakuMessage) error { diff --git a/waku/v2/rpc/filter_test.go b/waku/v2/rpc/filter_test.go index d720a73a..3469b538 100644 --- a/waku/v2/rpc/filter_test.go +++ b/waku/v2/rpc/filter_test.go @@ -50,14 +50,18 @@ func TestFilterSubscription(t *testing.T) { host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - node := relay.NewWakuRelay(host, v2.NewBroadcaster(10), 0, timesource.NewDefaultClock(), utils.Logger()) + b := v2.NewBroadcaster(10) + require.NoError(t, b.Start(context.Background())) + node := relay.NewWakuRelay(host, b, 0, timesource.NewDefaultClock(), utils.Logger()) err = node.Start(context.Background()) require.NoError(t, err) _, err = node.SubscribeToTopic(context.Background(), testTopic) require.NoError(t, err) - f := legacy_filter.NewWakuFilter(host, v2.NewBroadcaster(10), false, timesource.NewDefaultClock(), utils.Logger()) + b2 := v2.NewBroadcaster(10) + require.NoError(t, b2.Start(context.Background())) + f := legacy_filter.NewWakuFilter(host, b2, false, timesource.NewDefaultClock(), utils.Logger()) err = f.Start(context.Background()) require.NoError(t, err)