mirror of
https://github.com/status-im/status-go.git
synced 2025-01-21 12:11:44 +00:00
Added PayloadManager and outbound pairing tests
This commit is contained in:
parent
366c088ec5
commit
566db2e3df
@ -1,23 +1,28 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
type PairingClient struct {
|
||||
*http.Client
|
||||
|
||||
baseAddress *url.URL
|
||||
certPEM []byte
|
||||
privateKey *ecdsa.PrivateKey
|
||||
aesKey []byte
|
||||
serverMode Mode
|
||||
payload *PayloadManager
|
||||
}
|
||||
|
||||
func NewClient(c *ConnectionParams) (*Client, error) {
|
||||
func NewPairingClient(c *ConnectionParams) (*PairingClient, error) {
|
||||
u, certPem, err := c.Generate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -40,10 +45,56 @@ func NewClient(c *ConnectionParams) (*Client, error) {
|
||||
},
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ek, err := makeEncryptionKey(c.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PairingClient{
|
||||
Client: &http.Client{Transport: tr},
|
||||
baseAddress: u,
|
||||
certPEM: certPem,
|
||||
privateKey: c.privateKey,
|
||||
aesKey: ek,
|
||||
serverMode: c.serverMode,
|
||||
payload: new(PayloadManager),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PairingClient) MountPayload(data []byte) {
|
||||
s.payload.Mount(data)
|
||||
}
|
||||
|
||||
func (c *PairingClient) PairAccount() error {
|
||||
switch c.serverMode {
|
||||
case Receiving:
|
||||
return c.sendAccountData()
|
||||
case Sending:
|
||||
return c.receiveAccountData()
|
||||
default:
|
||||
return fmt.Errorf("unrecognised server mode '%d'", c.serverMode)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PairingClient) sendAccountData() error {
|
||||
c.baseAddress.Path = pairingReceive
|
||||
_, err := c.Post(c.baseAddress.String(), "application/octet-stream", bytes.NewBuffer(c.payload.ToSend()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PairingClient) receiveAccountData() error {
|
||||
c.baseAddress.Path = pairingSend
|
||||
resp, err := c.Get(c.baseAddress.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, _ := ioutil.ReadAll(resp.Body)
|
||||
c.payload.Receive(content)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -3,8 +3,11 @@ package server
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -70,3 +73,37 @@ func (tcc *TestCertComponents) SetupCertComponents(t *testing.T) {
|
||||
|
||||
tcc.NotAfter = tcc.NotBefore.Add(time.Hour)
|
||||
}
|
||||
|
||||
type TestPairingServerComponents struct {
|
||||
EphemeralPK *ecdsa.PrivateKey
|
||||
OutboundIP net.IP
|
||||
CertTime time.Time
|
||||
Cert tls.Certificate
|
||||
PS *PairingServer
|
||||
}
|
||||
|
||||
func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// Get 3 key components for tls.cert generation
|
||||
// 1) Ephemeral private key
|
||||
tpsc.EphemeralPK, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2) Device outbound IP address
|
||||
tpsc.OutboundIP, err = GetOutboundIP()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3) NotBefore time
|
||||
tpsc.CertTime = time.Now()
|
||||
|
||||
// Generate tls.Certificate and Server
|
||||
tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, tpsc.CertTime, tpsc.OutboundIP.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
tpsc.PS, err = NewPairingServer(&Config{
|
||||
PK: tpsc.EphemeralPK,
|
||||
Cert: &tpsc.Cert,
|
||||
Hostname: tpsc.OutboundIP.String()})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
11
server/encryption.go
Normal file
11
server/encryption.go
Normal file
@ -0,0 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
|
||||
"github.com/status-im/status-go/protocol/common"
|
||||
)
|
||||
|
||||
func makeEncryptionKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
return common.MakeECDHSharedKey(key, &key.PublicKey)
|
||||
}
|
@ -1,7 +1,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -18,6 +20,11 @@ const (
|
||||
imagesPath = basePath + "/images"
|
||||
audioPath = basePath + "/audio"
|
||||
ipfsPath = "/ipfs"
|
||||
|
||||
// Handler routes for pairing
|
||||
pairingBase = "/pairing"
|
||||
pairingSend = pairingBase + "/send"
|
||||
pairingReceive = pairingBase + "/receive"
|
||||
)
|
||||
|
||||
type HandlerPatternMap map[string]http.HandlerFunc
|
||||
@ -133,3 +140,26 @@ func handleIPFS(downloader *ipfs.Downloader, logger *zap.Logger) func(w http.Res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handlePairingReceive(ps *PairingServer) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
content, err := ioutil.ReadAll(r.Body)
|
||||
ps.logger.Error("ioutil.ReadAll(r.Body)", zap.Error(err))
|
||||
ps.payload.Receive(content)
|
||||
}
|
||||
}
|
||||
|
||||
func handlePairingSend(ps *PairingServer) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
}
|
||||
|
||||
ps.payload.Mount(b)
|
||||
w.Write(b)
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
@ -13,6 +12,19 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetOutboundIPSuite(t *testing.T) {
|
||||
suite.Run(t, new(GetOutboundIPSuite))
|
||||
}
|
||||
|
||||
type GetOutboundIPSuite struct {
|
||||
suite.Suite
|
||||
TestPairingServerComponents
|
||||
}
|
||||
|
||||
func (s *GetOutboundIPSuite) SetupSuite() {
|
||||
s.SetupPairingServerComponents(s.T())
|
||||
}
|
||||
|
||||
func testHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
say, ok := r.URL.Query()["say"]
|
||||
@ -37,35 +49,17 @@ func makeThingToSay() (string, error) {
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func TestGetOutboundIPWithFullServerE2e(t *testing.T) {
|
||||
// Get 3 key components for tls.cert generation
|
||||
// 1) Ephemeral private key
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
func (goip *GetOutboundIPSuite) TestGetOutboundIPWithFullServerE2e(t *testing.T) {
|
||||
goip.PS.SetHandlers(HandlerPatternMap{"/hello": testHandler(t)})
|
||||
|
||||
// 2) Device outbound IP address
|
||||
ip, err := GetOutboundIP()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3) NotBefore time
|
||||
certTime := time.Now()
|
||||
|
||||
// Generate tls.Certificate and Server
|
||||
cert, _, err := GenerateCertFromKey(pk, certTime, ip.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
s := NewPairingServer(&Config{pk, &cert, ip.String(), Sending})
|
||||
|
||||
s.SetHandlers(HandlerPatternMap{"/hello": testHandler(t)})
|
||||
|
||||
err = s.Start()
|
||||
err := goip.PS.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give time for the sever to be ready, hacky I know, I'll iron this out
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Server generates a QR code connection string
|
||||
cp, err := s.MakeConnectionParams()
|
||||
cp, err := goip.PS.MakeConnectionParams()
|
||||
require.NoError(t, err)
|
||||
|
||||
qr, err := cp.ToString()
|
||||
@ -76,7 +70,7 @@ func TestGetOutboundIPWithFullServerE2e(t *testing.T) {
|
||||
err = ccp.FromString(qr)
|
||||
require.NoError(t, err)
|
||||
|
||||
c, err := NewClient(ccp)
|
||||
c, err := NewPairingClient(ccp)
|
||||
require.NoError(t, err)
|
||||
|
||||
thing, err := makeThingToSay()
|
||||
|
22
server/payload_manager.go
Normal file
22
server/payload_manager.go
Normal file
@ -0,0 +1,22 @@
|
||||
package server
|
||||
|
||||
type PayloadManager struct {
|
||||
toSend []byte
|
||||
received []byte
|
||||
}
|
||||
|
||||
func (pm *PayloadManager) Mount(data []byte) {
|
||||
pm.toSend = data
|
||||
}
|
||||
|
||||
func (pm *PayloadManager) Receive(data []byte) {
|
||||
pm.received = data
|
||||
}
|
||||
|
||||
func (pm *PayloadManager) ToSend() []byte {
|
||||
return pm.toSend
|
||||
}
|
||||
|
||||
func (pm *PayloadManager) Received() []byte {
|
||||
return pm.received
|
||||
}
|
@ -10,8 +10,10 @@ import (
|
||||
type PairingServer struct {
|
||||
Server
|
||||
|
||||
pk *ecdsa.PrivateKey
|
||||
mode Mode
|
||||
pk *ecdsa.PrivateKey
|
||||
aesKey []byte
|
||||
mode Mode
|
||||
payload *PayloadManager
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@ -22,13 +24,20 @@ type Config struct {
|
||||
}
|
||||
|
||||
// NewPairingServer returns a *NewPairingServer init from the given *Config
|
||||
func NewPairingServer(config *Config) *PairingServer {
|
||||
func NewPairingServer(config *Config) (*PairingServer, error) {
|
||||
ek, err := makeEncryptionKey(config.PK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PairingServer{Server: NewServer(
|
||||
config.Cert,
|
||||
config.Hostname,
|
||||
),
|
||||
pk: config.PK,
|
||||
mode: config.Mode}
|
||||
pk: config.PK,
|
||||
aesKey: ek,
|
||||
mode: config.Mode,
|
||||
payload: new(PayloadManager)}, nil
|
||||
}
|
||||
|
||||
// MakeConnectionParams generates a *ConnectionParams based on the Server's current state
|
||||
@ -58,3 +67,28 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
|
||||
|
||||
return NewConnectionParams(netIP, s.port, s.pk, s.cert.Leaf.NotBefore, s.mode), nil
|
||||
}
|
||||
|
||||
func (s *PairingServer) MountPayload(data []byte) {
|
||||
s.payload.Mount(data)
|
||||
}
|
||||
|
||||
func (s *PairingServer) StartPairing() error {
|
||||
switch s.mode {
|
||||
case Receiving:
|
||||
return s.startReceivingAccountData()
|
||||
case Sending:
|
||||
return s.startSendingAccountData()
|
||||
default:
|
||||
return fmt.Errorf("invalid server mode '%d'", s.mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PairingServer) startReceivingAccountData() error {
|
||||
s.SetHandlers(HandlerPatternMap{pairingReceive: handlePairingReceive(s)})
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
func (s *PairingServer) startSendingAccountData() error {
|
||||
s.SetHandlers(HandlerPatternMap{pairingSend: handlePairingSend(s)})
|
||||
return s.Start()
|
||||
}
|
||||
|
75
server/server_pairing_test.go
Normal file
75
server/server_pairing_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
func TestPairingServerSuite(t *testing.T) {
|
||||
suite.Run(t, new(PairingServerSuite))
|
||||
}
|
||||
|
||||
type PairingServerSuite struct {
|
||||
suite.Suite
|
||||
TestPairingServerComponents
|
||||
}
|
||||
|
||||
func (s *PairingServerSuite) SetupSuite() {
|
||||
s.SetupPairingServerComponents(s.T())
|
||||
}
|
||||
|
||||
func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
||||
modes := []Mode{
|
||||
Receiving,
|
||||
Sending,
|
||||
}
|
||||
|
||||
for _, m := range modes {
|
||||
s.PS.mode = m
|
||||
|
||||
// Random payload
|
||||
data := make([]byte, 32)
|
||||
_, err := rand.Read(data)
|
||||
s.Require().NoError(err)
|
||||
|
||||
if m == Sending {
|
||||
s.PS.MountPayload(data)
|
||||
}
|
||||
|
||||
err = s.PS.StartPairing()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Give time for the sever to be ready, hacky I know, I'll iron this out
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
cp, err := s.PS.MakeConnectionParams()
|
||||
s.Require().NoError(err)
|
||||
|
||||
qr, err := cp.ToString()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Client reads QR code and parses the connection string
|
||||
ccp := new(ConnectionParams)
|
||||
err = ccp.FromString(qr)
|
||||
s.Require().NoError(err)
|
||||
|
||||
c, err := NewPairingClient(ccp)
|
||||
s.Require().NoError(err)
|
||||
|
||||
if m == Receiving {
|
||||
c.MountPayload(data)
|
||||
}
|
||||
|
||||
err = c.PairAccount()
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().Equal(s.PS.payload.ToSend(), c.payload.Received())
|
||||
s.Require().Equal(s.PS.payload.Received(), c.payload.ToSend())
|
||||
|
||||
// Reset the server's PayloadManager
|
||||
s.PS.payload = new(PayloadManager)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user