166 lines
4.3 KiB
Go
166 lines
4.3 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"github.com/status-im/status-go/protocol/common"
|
|
"github.com/status-im/status-go/signal"
|
|
)
|
|
|
|
func setupServer(t *testing.T) (*Server, string) {
|
|
srv := NewServer()
|
|
srv.Setup()
|
|
err := srv.Listen("localhost:0")
|
|
require.NoError(t, err)
|
|
|
|
addr := srv.Address()
|
|
|
|
// Check URL
|
|
serverURLString := fmt.Sprintf("http://%s", addr)
|
|
serverURL, err := url.Parse(serverURLString)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, serverURL)
|
|
require.NotZero(t, serverURL.Port())
|
|
|
|
return srv, addr
|
|
}
|
|
|
|
func shutdownServer(srv *Server) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
srv.Stop(ctx)
|
|
}
|
|
|
|
func TestSignals(t *testing.T) {
|
|
srv, serverURLString := setupServer(t)
|
|
go srv.Serve()
|
|
defer shutdownServer(srv)
|
|
|
|
signalsURL := fmt.Sprintf("ws://%s/signals", serverURLString)
|
|
connection, _, err := websocket.DefaultDialer.Dial(signalsURL, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, connection)
|
|
defer func() {
|
|
err := connection.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
sentEvent := signal.MessageDeliveredSignal{
|
|
ChatID: randomAlphabeticalString(t, 10),
|
|
MessageID: randomAlphabeticalString(t, 10),
|
|
}
|
|
|
|
signal.SendMessageDelivered(sentEvent.ChatID, sentEvent.MessageID)
|
|
|
|
messageType, data, err := connection.ReadMessage()
|
|
require.NoError(t, err)
|
|
require.Equal(t, websocket.TextMessage, messageType)
|
|
|
|
receivedSignal := signal.Envelope{}
|
|
err = json.Unmarshal(data, &receivedSignal)
|
|
require.NoError(t, err)
|
|
require.Equal(t, signal.EventMesssageDelivered, receivedSignal.Type)
|
|
require.NotNil(t, receivedSignal.Event)
|
|
|
|
// Convert `interface{}` to json and then back to the original struct
|
|
tempJson, err := json.Marshal(receivedSignal.Event)
|
|
require.NoError(t, err)
|
|
|
|
receivedEvent := signal.MessageDeliveredSignal{}
|
|
err = json.Unmarshal(tempJson, &receivedEvent)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sentEvent, receivedEvent)
|
|
}
|
|
|
|
func TestMobileAPI(t *testing.T) {
|
|
// Setup fake endpoints
|
|
endpointsWithResponse := EndpointsWithRequest
|
|
endpointsNoRequest := EndpointsWithoutRequest
|
|
endpointsUnsupported := EndpointsUnsupported
|
|
t.Cleanup(func() {
|
|
EndpointsWithRequest = endpointsWithResponse
|
|
EndpointsWithoutRequest = endpointsNoRequest
|
|
EndpointsUnsupported = endpointsUnsupported
|
|
})
|
|
|
|
endpointWithResponse := "/" + randomAlphabeticalString(t, 5)
|
|
endpointNoRequest := "/" + randomAlphabeticalString(t, 5)
|
|
endpointUnsupported := "/" + randomAlphabeticalString(t, 5)
|
|
|
|
request1 := randomAlphabeticalString(t, 5)
|
|
response1 := randomAlphabeticalString(t, 5)
|
|
response2 := randomAlphabeticalString(t, 5)
|
|
|
|
EndpointsWithRequest = map[string]func(string) string{
|
|
endpointWithResponse: func(request string) string {
|
|
require.Equal(t, request1, request)
|
|
return response1
|
|
},
|
|
}
|
|
EndpointsWithoutRequest = map[string]func() string{
|
|
endpointNoRequest: func() string {
|
|
return response2
|
|
},
|
|
}
|
|
EndpointsUnsupported = []string{endpointUnsupported}
|
|
|
|
// Setup server
|
|
srv, _ := setupServer(t)
|
|
defer shutdownServer(srv)
|
|
go srv.Serve()
|
|
srv.RegisterMobileAPI()
|
|
|
|
requestBody := []byte(request1)
|
|
bodyReader := bytes.NewReader(requestBody)
|
|
|
|
port, err := srv.Port()
|
|
require.NoError(t, err)
|
|
|
|
serverURL := fmt.Sprintf("http://127.0.0.1:%d", port)
|
|
|
|
// Test endpoints with response
|
|
resp, err := http.Post(serverURL+endpointWithResponse, "application/text", bodyReader)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, response1, string(responseBody))
|
|
|
|
// Test endpoints with no request
|
|
resp, err = http.Get(serverURL + endpointNoRequest)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
responseBody, err = io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, response2, string(responseBody))
|
|
|
|
// Test unsupported endpoint
|
|
resp, err = http.Get(serverURL + endpointUnsupported)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
require.Equal(t, http.StatusNotImplemented, resp.StatusCode)
|
|
|
|
}
|
|
|
|
func randomAlphabeticalString(t *testing.T, n int) string {
|
|
s, err := common.RandomAlphabeticalString(n)
|
|
require.NoError(t, err)
|
|
return s
|
|
}
|