diff --git a/waku/node.go b/waku/node.go index 031e27b6..a2e7b1b9 100644 --- a/waku/node.go +++ b/waku/node.go @@ -370,7 +370,7 @@ func Execute(options Options) { if options.RESTServer.Enable { wg.Add(1) restServer = rest.NewWakuRest(wakuNode, options.RESTServer.Address, options.RESTServer.Port, options.RESTServer.Admin, options.RESTServer.Private, options.RESTServer.RelayCacheCapacity, logger) - restServer.Start(&wg) + restServer.Start(ctx, &wg) } wg.Wait() diff --git a/waku/v2/rest/relay.go b/waku/v2/rest/relay.go index 37babb52..d99b7296 100644 --- a/waku/v2/rest/relay.go +++ b/waku/v2/rest/relay.go @@ -1,6 +1,7 @@ package rest import ( + "context" "encoding/json" "net/http" "strings" @@ -18,8 +19,9 @@ const ROUTE_RELAY_SUBSCRIPTIONSV1 = "/relay/v1/subscriptions" const ROUTE_RELAY_MESSAGESV1 = "/relay/v1/messages/{topic}" type RelayService struct { - node *node.WakuNode - mux *mux.Router + node *node.WakuNode + mux *mux.Router + cancel context.CancelFunc log *zap.Logger @@ -65,18 +67,21 @@ func (r *RelayService) addEnvelope(envelope *protocol.Envelope) { r.messages[envelope.PubsubTopic()] = append(r.messages[envelope.PubsubTopic()], envelope.Message()) } -func (r *RelayService) Start() { +func (r *RelayService) Start(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + r.cancel = cancel + // Node may already be subscribed to some topics when Relay API handlers are installed. Let's add these for _, topic := range r.node.Relay().Topics() { r.log.Info("adding topic handler for existing subscription", zap.String("topic", topic)) r.messages[topic] = []*pb.WakuMessage{} } - r.runner.Start() + r.runner.Start(ctx) } func (r *RelayService) Stop() { - r.runner.Stop() + r.cancel() } func (d *RelayService) deleteV1Subscriptions(w http.ResponseWriter, r *http.Request) { diff --git a/waku/v2/rest/relay_test.go b/waku/v2/rest/relay_test.go index bc78017c..4181420e 100644 --- a/waku/v2/rest/relay_test.go +++ b/waku/v2/rest/relay_test.go @@ -54,7 +54,7 @@ func TestPostV1Message(t *testing.T) { func TestRelaySubscription(t *testing.T) { d := makeRelayService(t) - go d.Start() + go d.Start(context.Background()) defer d.Stop() topics := []string{"test"} @@ -96,10 +96,10 @@ func TestRelaySubscription(t *testing.T) { func TestRelayGetV1Messages(t *testing.T) { serviceA := makeRelayService(t) - go serviceA.Start() + go serviceA.Start(context.Background()) defer serviceA.Stop() serviceB := makeRelayService(t) - go serviceB.Start() + go serviceB.Start(context.Background()) defer serviceB.Stop() hostInfo, err := multiaddr.NewMultiaddr(fmt.Sprintf("/p2p/%s", serviceB.node.Host().ID().Pretty())) diff --git a/waku/v2/rest/runner.go b/waku/v2/rest/runner.go index f6baf075..1e1b5e92 100644 --- a/waku/v2/rest/runner.go +++ b/waku/v2/rest/runner.go @@ -1,6 +1,8 @@ package rest import ( + "context" + v2 "github.com/waku-org/go-waku/waku/v2" "github.com/waku-org/go-waku/waku/v2/protocol" ) @@ -10,24 +12,25 @@ type Adder func(msg *protocol.Envelope) type runnerService struct { broadcaster v2.Broadcaster ch chan *protocol.Envelope - quit chan bool + cancel context.CancelFunc adder Adder } func newRunnerService(broadcaster v2.Broadcaster, adder Adder) *runnerService { return &runnerService{ broadcaster: broadcaster, - quit: make(chan bool), adder: adder, } } -func (r *runnerService) Start() { +func (r *runnerService) Start(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) r.ch = make(chan *protocol.Envelope, 1024) + r.cancel = cancel r.broadcaster.Register(nil, r.ch) for { select { - case <-r.quit: + case <-ctx.Done(): return case envelope := <-r.ch: r.adder(envelope) @@ -36,7 +39,7 @@ func (r *runnerService) Start() { } func (r *runnerService) Stop() { - r.quit <- true + r.cancel() r.broadcaster.Unregister(nil, r.ch) close(r.ch) } diff --git a/waku/v2/rest/waku_rest.go b/waku/v2/rest/waku_rest.go index f9f323bd..b9a607bf 100644 --- a/waku/v2/rest/waku_rest.go +++ b/waku/v2/rest/waku_rest.go @@ -47,9 +47,9 @@ func NewWakuRest(node *node.WakuNode, address string, port int, enableAdmin bool return wrpc } -func (r *WakuRest) Start(wg *sync.WaitGroup) { +func (r *WakuRest) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - go r.relayService.Start() + go r.relayService.Start(ctx) go func() { _ = r.server.ListenAndServe() }()