From 566db2e3df1d3e9d2f688f6474853c1e7b5f8551 Mon Sep 17 00:00:00 2001 From: Samuel Hawksby-Robinson Date: Fri, 10 Jun 2022 16:32:15 +0100 Subject: [PATCH] Added PayloadManager and outbound pairing tests --- server/client.go | 57 ++++++++++++++++++++++++-- server/components_test.go | 37 +++++++++++++++++ server/encryption.go | 11 +++++ server/handlers.go | 30 ++++++++++++++ server/ips_test.go | 44 +++++++++----------- server/payload_manager.go | 22 ++++++++++ server/server_pairing.go | 44 +++++++++++++++++--- server/server_pairing_test.go | 75 +++++++++++++++++++++++++++++++++++ 8 files changed, 287 insertions(+), 33 deletions(-) create mode 100644 server/encryption.go create mode 100644 server/payload_manager.go create mode 100644 server/server_pairing_test.go diff --git a/server/client.go b/server/client.go index ae3a9fd4a..2e8cc1b44 100644 --- a/server/client.go +++ b/server/client.go @@ -1,23 +1,28 @@ package server import ( + "bytes" "crypto/ecdsa" "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" "net/http" "net/url" ) -type Client struct { +type PairingClient struct { *http.Client baseAddress *url.URL certPEM []byte privateKey *ecdsa.PrivateKey + aesKey []byte + serverMode Mode + payload *PayloadManager } -func NewClient(c *ConnectionParams) (*Client, error) { +func NewPairingClient(c *ConnectionParams) (*PairingClient, error) { u, certPem, err := c.Generate() if err != nil { return nil, err @@ -40,10 +45,56 @@ func NewClient(c *ConnectionParams) (*Client, error) { }, } - return &Client{ + ek, err := makeEncryptionKey(c.privateKey) + if err != nil { + return nil, err + } + + return &PairingClient{ Client: &http.Client{Transport: tr}, baseAddress: u, certPEM: certPem, privateKey: c.privateKey, + aesKey: ek, + serverMode: c.serverMode, + payload: new(PayloadManager), }, nil } + +func (s *PairingClient) MountPayload(data []byte) { + s.payload.Mount(data) +} + +func (c *PairingClient) PairAccount() error { + switch c.serverMode { + case Receiving: + return c.sendAccountData() + case Sending: + return c.receiveAccountData() + default: + return fmt.Errorf("unrecognised server mode '%d'", c.serverMode) + } +} + +func (c *PairingClient) sendAccountData() error { + c.baseAddress.Path = pairingReceive + _, err := c.Post(c.baseAddress.String(), "application/octet-stream", bytes.NewBuffer(c.payload.ToSend())) + if err != nil { + return err + } + + return nil +} + +func (c *PairingClient) receiveAccountData() error { + c.baseAddress.Path = pairingSend + resp, err := c.Get(c.baseAddress.String()) + if err != nil { + return err + } + + content, _ := ioutil.ReadAll(resp.Body) + c.payload.Receive(content) + + return nil +} diff --git a/server/components_test.go b/server/components_test.go index af870120f..72a139df8 100644 --- a/server/components_test.go +++ b/server/components_test.go @@ -3,8 +3,11 @@ package server import ( "crypto/ecdsa" "crypto/elliptic" + "crypto/rand" + "crypto/tls" "encoding/asn1" "math/big" + "net" "testing" "time" @@ -70,3 +73,37 @@ func (tcc *TestCertComponents) SetupCertComponents(t *testing.T) { tcc.NotAfter = tcc.NotBefore.Add(time.Hour) } + +type TestPairingServerComponents struct { + EphemeralPK *ecdsa.PrivateKey + OutboundIP net.IP + CertTime time.Time + Cert tls.Certificate + PS *PairingServer +} + +func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing.T) { + var err error + + // Get 3 key components for tls.cert generation + // 1) Ephemeral private key + tpsc.EphemeralPK, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // 2) Device outbound IP address + tpsc.OutboundIP, err = GetOutboundIP() + require.NoError(t, err) + + // 3) NotBefore time + tpsc.CertTime = time.Now() + + // Generate tls.Certificate and Server + tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, tpsc.CertTime, tpsc.OutboundIP.String()) + require.NoError(t, err) + + tpsc.PS, err = NewPairingServer(&Config{ + PK: tpsc.EphemeralPK, + Cert: &tpsc.Cert, + Hostname: tpsc.OutboundIP.String()}) + require.NoError(t, err) +} diff --git a/server/encryption.go b/server/encryption.go new file mode 100644 index 000000000..58f59add7 --- /dev/null +++ b/server/encryption.go @@ -0,0 +1,11 @@ +package server + +import ( + "crypto/ecdsa" + + "github.com/status-im/status-go/protocol/common" +) + +func makeEncryptionKey(key *ecdsa.PrivateKey) ([]byte, error) { + return common.MakeECDHSharedKey(key, &key.PublicKey) +} diff --git a/server/handlers.go b/server/handlers.go index 0862f3ea5..0674a6256 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1,7 +1,9 @@ package server import ( + "crypto/rand" "database/sql" + "io/ioutil" "net/http" "time" @@ -18,6 +20,11 @@ const ( imagesPath = basePath + "/images" audioPath = basePath + "/audio" ipfsPath = "/ipfs" + + // Handler routes for pairing + pairingBase = "/pairing" + pairingSend = pairingBase + "/send" + pairingReceive = pairingBase + "/receive" ) type HandlerPatternMap map[string]http.HandlerFunc @@ -133,3 +140,26 @@ func handleIPFS(downloader *ipfs.Downloader, logger *zap.Logger) func(w http.Res } } } + +func handlePairingReceive(ps *PairingServer) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + content, err := ioutil.ReadAll(r.Body) + ps.logger.Error("ioutil.ReadAll(r.Body)", zap.Error(err)) + ps.payload.Receive(content) + } +} + +func handlePairingSend(ps *PairingServer) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + w.Write([]byte(err.Error())) + } + + ps.payload.Mount(b) + w.Write(b) + } +} diff --git a/server/ips_test.go b/server/ips_test.go index 7e31a7aef..58845b7dd 100644 --- a/server/ips_test.go +++ b/server/ips_test.go @@ -1,10 +1,9 @@ package server import ( - "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "encoding/hex" + "github.com/stretchr/testify/suite" "io/ioutil" "net/http" "testing" @@ -13,6 +12,19 @@ import ( "github.com/stretchr/testify/require" ) +func TestGetOutboundIPSuite(t *testing.T) { + suite.Run(t, new(GetOutboundIPSuite)) +} + +type GetOutboundIPSuite struct { + suite.Suite + TestPairingServerComponents +} + +func (s *GetOutboundIPSuite) SetupSuite() { + s.SetupPairingServerComponents(s.T()) +} + func testHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { say, ok := r.URL.Query()["say"] @@ -37,35 +49,17 @@ func makeThingToSay() (string, error) { return hex.EncodeToString(b), nil } -func TestGetOutboundIPWithFullServerE2e(t *testing.T) { - // Get 3 key components for tls.cert generation - // 1) Ephemeral private key - pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) +func (goip *GetOutboundIPSuite) TestGetOutboundIPWithFullServerE2e(t *testing.T) { + goip.PS.SetHandlers(HandlerPatternMap{"/hello": testHandler(t)}) - // 2) Device outbound IP address - ip, err := GetOutboundIP() - require.NoError(t, err) - - // 3) NotBefore time - certTime := time.Now() - - // Generate tls.Certificate and Server - cert, _, err := GenerateCertFromKey(pk, certTime, ip.String()) - require.NoError(t, err) - - s := NewPairingServer(&Config{pk, &cert, ip.String(), Sending}) - - s.SetHandlers(HandlerPatternMap{"/hello": testHandler(t)}) - - err = s.Start() + err := goip.PS.Start() require.NoError(t, err) // Give time for the sever to be ready, hacky I know, I'll iron this out time.Sleep(100 * time.Millisecond) // Server generates a QR code connection string - cp, err := s.MakeConnectionParams() + cp, err := goip.PS.MakeConnectionParams() require.NoError(t, err) qr, err := cp.ToString() @@ -76,7 +70,7 @@ func TestGetOutboundIPWithFullServerE2e(t *testing.T) { err = ccp.FromString(qr) require.NoError(t, err) - c, err := NewClient(ccp) + c, err := NewPairingClient(ccp) require.NoError(t, err) thing, err := makeThingToSay() diff --git a/server/payload_manager.go b/server/payload_manager.go new file mode 100644 index 000000000..26abd3ea8 --- /dev/null +++ b/server/payload_manager.go @@ -0,0 +1,22 @@ +package server + +type PayloadManager struct { + toSend []byte + received []byte +} + +func (pm *PayloadManager) Mount(data []byte) { + pm.toSend = data +} + +func (pm *PayloadManager) Receive(data []byte) { + pm.received = data +} + +func (pm *PayloadManager) ToSend() []byte { + return pm.toSend +} + +func (pm *PayloadManager) Received() []byte { + return pm.received +} diff --git a/server/server_pairing.go b/server/server_pairing.go index c374486a5..2641be8a9 100644 --- a/server/server_pairing.go +++ b/server/server_pairing.go @@ -10,8 +10,10 @@ import ( type PairingServer struct { Server - pk *ecdsa.PrivateKey - mode Mode + pk *ecdsa.PrivateKey + aesKey []byte + mode Mode + payload *PayloadManager } type Config struct { @@ -22,13 +24,20 @@ type Config struct { } // NewPairingServer returns a *NewPairingServer init from the given *Config -func NewPairingServer(config *Config) *PairingServer { +func NewPairingServer(config *Config) (*PairingServer, error) { + ek, err := makeEncryptionKey(config.PK) + if err != nil { + return nil, err + } + return &PairingServer{Server: NewServer( config.Cert, config.Hostname, ), - pk: config.PK, - mode: config.Mode} + pk: config.PK, + aesKey: ek, + mode: config.Mode, + payload: new(PayloadManager)}, nil } // MakeConnectionParams generates a *ConnectionParams based on the Server's current state @@ -58,3 +67,28 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) { return NewConnectionParams(netIP, s.port, s.pk, s.cert.Leaf.NotBefore, s.mode), nil } + +func (s *PairingServer) MountPayload(data []byte) { + s.payload.Mount(data) +} + +func (s *PairingServer) StartPairing() error { + switch s.mode { + case Receiving: + return s.startReceivingAccountData() + case Sending: + return s.startSendingAccountData() + default: + return fmt.Errorf("invalid server mode '%d'", s.mode) + } +} + +func (s *PairingServer) startReceivingAccountData() error { + s.SetHandlers(HandlerPatternMap{pairingReceive: handlePairingReceive(s)}) + return s.Start() +} + +func (s *PairingServer) startSendingAccountData() error { + s.SetHandlers(HandlerPatternMap{pairingSend: handlePairingSend(s)}) + return s.Start() +} diff --git a/server/server_pairing_test.go b/server/server_pairing_test.go new file mode 100644 index 000000000..d8de6dc85 --- /dev/null +++ b/server/server_pairing_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "crypto/rand" + "testing" + "time" + + "github.com/stretchr/testify/suite" +) + +func TestPairingServerSuite(t *testing.T) { + suite.Run(t, new(PairingServerSuite)) +} + +type PairingServerSuite struct { + suite.Suite + TestPairingServerComponents +} + +func (s *PairingServerSuite) SetupSuite() { + s.SetupPairingServerComponents(s.T()) +} + +func (s *PairingServerSuite) TestPairingServer_StartPairing() { + modes := []Mode{ + Receiving, + Sending, + } + + for _, m := range modes { + s.PS.mode = m + + // Random payload + data := make([]byte, 32) + _, err := rand.Read(data) + s.Require().NoError(err) + + if m == Sending { + s.PS.MountPayload(data) + } + + err = s.PS.StartPairing() + s.Require().NoError(err) + + // Give time for the sever to be ready, hacky I know, I'll iron this out + time.Sleep(10 * time.Millisecond) + + cp, err := s.PS.MakeConnectionParams() + s.Require().NoError(err) + + qr, err := cp.ToString() + s.Require().NoError(err) + + // Client reads QR code and parses the connection string + ccp := new(ConnectionParams) + err = ccp.FromString(qr) + s.Require().NoError(err) + + c, err := NewPairingClient(ccp) + s.Require().NoError(err) + + if m == Receiving { + c.MountPayload(data) + } + + err = c.PairAccount() + s.Require().NoError(err) + + s.Require().Equal(s.PS.payload.ToSend(), c.payload.Received()) + s.Require().Equal(s.PS.payload.Received(), c.payload.ToSend()) + + // Reset the server's PayloadManager + s.PS.payload = new(PayloadManager) + } +}