diff --git a/cmd/statusd/main.go b/cmd/statusd/main.go index 6bc248dd4..d22ace42f 100644 --- a/cmd/statusd/main.go +++ b/cmd/statusd/main.go @@ -23,6 +23,7 @@ import ( "github.com/status-im/status-go/api" "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/cmd/statusd/server" "github.com/status-im/status-go/common/dbsetup" gethbridge "github.com/status-im/status-go/eth-node/bridge/geth" "github.com/status-im/status-go/eth-node/crypto" @@ -81,6 +82,7 @@ var ( ), ) listenAddr = flag.String("addr", "", "address to bind listener to") + serverAddr = flag.String("server", "", "Address `host:port` for HTTP API server of statusd") // don't change the name of this flag, https://github.com/ethereum/go-ethereum/blob/master/metrics/metrics.go#L41 metricsEnabled = flag.Bool("metrics", false, "Expose ethereum metrics with debug_metrics jsonrpc call") @@ -168,6 +170,22 @@ func main() { return } + if serverAddr != nil && *serverAddr != "" { + srv := server.NewServer() + srv.Setup() + err = srv.Listen(*serverAddr) + if err != nil { + logger.Error("failed to start server", "error", err) + return + } + log.Info("server started", "address", srv.Address()) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + srv.Stop(ctx) + }() + } + backend := api.NewGethStatusBackend() if config.NodeKey == "" { logger.Error("node key needs to be set if running a push notification server") diff --git a/cmd/statusd/server/signals_server.go b/cmd/statusd/server/signals_server.go new file mode 100644 index 000000000..539a834f3 --- /dev/null +++ b/cmd/statusd/server/signals_server.go @@ -0,0 +1,115 @@ +package server + +import ( + "context" + "errors" + "net" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/ethereum/go-ethereum/log" + + "github.com/status-im/status-go/signal" +) + +type Server struct { + server *http.Server + lock sync.Mutex + connections map[*websocket.Conn]struct{} + address string +} + +func NewServer() *Server { + return &Server{ + connections: make(map[*websocket.Conn]struct{}, 1), + } +} + +func (s *Server) Address() string { + return s.address +} + +func (s *Server) Setup() { + signal.SetMobileSignalHandler(s.signalHandler) +} + +func (s *Server) signalHandler(data []byte) { + s.lock.Lock() + defer s.lock.Unlock() + + for connection := range s.connections { + err := connection.WriteMessage(websocket.TextMessage, data) + if err != nil { + log.Error("failed to write message: %w", err) + } + } +} + +func (s *Server) Listen(address string) error { + if s.server != nil { + return errors.New("server already started") + } + + s.server = &http.Server{ + Addr: address, + ReadHeaderTimeout: 5 * time.Second, + } + + http.HandleFunc("/signals", s.signals) + + listener, err := net.Listen("tcp", address) + if err != nil { + return err + } + + s.address = listener.Addr().String() + + go func() { + err := s.server.Serve(listener) + if !errors.Is(err, http.ErrServerClosed) { + log.Error("signals server closed with error: %w", err) + } + }() + + return nil +} + +func (s *Server) Stop(ctx context.Context) { + for connection := range s.connections { + err := connection.Close() + if err != nil { + log.Error("failed to close connection: %w", err) + } + delete(s.connections, connection) + } + + err := s.server.Shutdown(ctx) + if err != nil { + log.Error("failed to shutdown signals server: %w", err) + } + + s.server = nil + s.address = "" +} + +func (s *Server) signals(w http.ResponseWriter, r *http.Request) { + s.lock.Lock() + defer s.lock.Unlock() + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + } + + connection, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Error("failed to upgrade connection: %w", err) + return + } + + s.connections[connection] = struct{}{} +} diff --git a/cmd/statusd/server/signals_server_test.go b/cmd/statusd/server/signals_server_test.go new file mode 100644 index 000000000..fbc5ff784 --- /dev/null +++ b/cmd/statusd/server/signals_server_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gorilla/websocket" + + "github.com/status-im/status-go/protocol/common" + "github.com/status-im/status-go/signal" +) + +func TestSignalsServer(t *testing.T) { + server := NewServer() + server.Setup() + err := server.Listen("localhost:0") + require.NoError(t, err) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + server.Stop(ctx) + }() + + addr := server.Address() + serverURLString := fmt.Sprintf("ws://%s", addr) + serverURL, err := url.Parse(serverURLString) + require.NoError(t, err) + require.NotZero(t, serverURL.Port()) + + connection, _, err := websocket.DefaultDialer.Dial(serverURLString+"/signals", nil) + require.NoError(t, err) + require.NotNil(t, connection) + defer func() { + err := connection.Close() + require.NoError(t, err) + }() + + sentEvent := signal.MessageDeliveredSignal{ + ChatID: randomAlphabeticalString(t, 10), + MessageID: randomAlphabeticalString(t, 10), + } + + signal.SendMessageDelivered(sentEvent.ChatID, sentEvent.MessageID) + + messageType, data, err := connection.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.TextMessage, messageType) + + receivedSignal := signal.Envelope{} + err = json.Unmarshal(data, &receivedSignal) + require.NoError(t, err) + require.Equal(t, signal.EventMesssageDelivered, receivedSignal.Type) + require.NotNil(t, receivedSignal.Event) + + // Convert `interface{}` to json and then back to the original struct + tempJson, err := json.Marshal(receivedSignal.Event) + require.NoError(t, err) + + receivedEvent := signal.MessageDeliveredSignal{} + err = json.Unmarshal(tempJson, &receivedEvent) + require.NoError(t, err) + require.Equal(t, sentEvent, receivedEvent) +} + +func randomAlphabeticalString(t *testing.T, n int) string { + s, err := common.RandomAlphabeticalString(n) + require.NoError(t, err) + return s +} diff --git a/go.mod b/go.mod index f3cdbb5fa..c5a4a5e90 100644 --- a/go.mod +++ b/go.mod @@ -85,6 +85,7 @@ require ( github.com/bits-and-blooms/bloom/v3 v3.7.0 github.com/cenkalti/backoff/v4 v4.2.1 github.com/gorilla/sessions v1.2.1 + github.com/gorilla/websocket v1.5.3 github.com/ipfs/go-log/v2 v2.5.1 github.com/jellydator/ttlcache/v3 v3.2.0 github.com/jmoiron/sqlx v1.3.5 @@ -171,7 +172,6 @@ require ( github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20240207164012-fb44976bdcd5 // indirect github.com/gorilla/securecookie v1.1.1 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-bexpr v0.1.10 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect