More refactor to increase testability

This commit is contained in:
Samuel Hawksby-Robinson 2022-07-01 16:37:53 +01:00
parent 0e878d55d2
commit 215dbac09a
7 changed files with 243 additions and 121 deletions

View File

@ -9,21 +9,19 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"github.com/status-im/status-go/multiaccounts"
) )
type PairingClient struct { type PairingClient struct {
*http.Client *http.Client
baseAddress *url.URL baseAddress *url.URL
certPEM []byte certPEM []byte
privateKey *ecdsa.PrivateKey privateKey *ecdsa.PrivateKey
serverMode Mode serverMode Mode
payload *PairingPayloadManager PayloadManager PayloadManager
} }
func NewPairingClient(c *ConnectionParams, db *multiaccounts.Database) (*PairingClient, error) { func NewPairingClient(c *ConnectionParams, config *PairingPayloadManagerConfig) (*PairingClient, error) {
u, certPem, err := c.Generate() u, certPem, err := c.Generate()
if err != nil { if err != nil {
return nil, err return nil, err
@ -46,25 +44,21 @@ func NewPairingClient(c *ConnectionParams, db *multiaccounts.Database) (*Pairing
}, },
} }
pm, err := NewPairingPayloadManager(c.privateKey, db) pm, err := NewPairingPayloadManager(c.privateKey, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &PairingClient{ return &PairingClient{
Client: &http.Client{Transport: tr}, Client: &http.Client{Transport: tr},
baseAddress: u, baseAddress: u,
certPEM: certPem, certPEM: certPem,
privateKey: c.privateKey, privateKey: c.privateKey,
serverMode: c.serverMode, serverMode: c.serverMode,
payload: pm, PayloadManager: pm,
}, nil }, nil
} }
func (c *PairingClient) MountPayload(data []byte) error {
return c.payload.pem.Mount(data)
}
func (c *PairingClient) PairAccount() error { func (c *PairingClient) PairAccount() error {
switch c.serverMode { switch c.serverMode {
case Receiving: case Receiving:
@ -78,7 +72,7 @@ func (c *PairingClient) PairAccount() error {
func (c *PairingClient) sendAccountData() error { func (c *PairingClient) sendAccountData() error {
c.baseAddress.Path = pairingReceive c.baseAddress.Path = pairingReceive
_, err := c.Post(c.baseAddress.String(), "application/octet-stream", bytes.NewBuffer(c.payload.pem.ToSend())) _, err := c.Post(c.baseAddress.String(), "application/octet-stream", bytes.NewBuffer(c.PayloadManager.ToSend()))
if err != nil { if err != nil {
return err return err
} }
@ -98,5 +92,5 @@ func (c *PairingClient) receiveAccountData() error {
return err return err
} }
return c.payload.pem.Receive(payload) return c.PayloadManager.Receive(payload)
} }

View File

@ -104,6 +104,44 @@ func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing
tpsc.PS, err = NewPairingServer(&Config{ tpsc.PS, err = NewPairingServer(&Config{
PK: tpsc.EphemeralPK, PK: tpsc.EphemeralPK,
Cert: &tpsc.Cert, Cert: &tpsc.Cert,
Hostname: tpsc.OutboundIP.String()}, nil) Hostname: tpsc.OutboundIP.String()})
require.NoError(t, err) require.NoError(t, err)
} }
type MockEncryptOnlyPayloadManager struct {
pem *PayloadEncryptionManager
}
func NewMockEncryptOnlyPayloadManager(pk *ecdsa.PrivateKey) (*MockEncryptOnlyPayloadManager, error) {
pem, err := NewPayloadEncryptionManager(pk)
if err != nil {
return nil, err
}
return &MockEncryptOnlyPayloadManager{
pem: pem,
}, nil
}
func (m *MockEncryptOnlyPayloadManager) Mount() error {
// Make a random payload
data := make([]byte, 32)
_, err := rand.Read(data)
if err != nil {
return err
}
return m.pem.Encrypt(data)
}
func (m *MockEncryptOnlyPayloadManager) Receive(data []byte) error {
return m.pem.Decrypt(data)
}
func (m *MockEncryptOnlyPayloadManager) ToSend() []byte {
return m.pem.ToSend()
}
func (m *MockEncryptOnlyPayloadManager) Received() []byte {
return m.pem.Received()
}

View File

@ -147,9 +147,9 @@ func handlePairingReceive(ps *PairingServer) func(w http.ResponseWriter, r *http
ps.logger.Error("ioutil.ReadAll(r.Body)", zap.Error(err)) ps.logger.Error("ioutil.ReadAll(r.Body)", zap.Error(err))
} }
err = ps.payload.pem.Receive(payload) err = ps.PayloadManager.Receive(payload)
if err != nil { if err != nil {
ps.logger.Error("ps.payload.Receive(payload)", zap.Error(err)) ps.logger.Error("ps.PayloadManager.Receive(payload)", zap.Error(err))
} }
} }
} }
@ -157,9 +157,9 @@ func handlePairingReceive(ps *PairingServer) func(w http.ResponseWriter, r *http
func handlePairingSend(ps *PairingServer) func(w http.ResponseWriter, r *http.Request) { func handlePairingSend(ps *PairingServer) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Type", "application/octet-stream")
_, err := w.Write(ps.payload.pem.ToSend()) _, err := w.Write(ps.PayloadManager.ToSend())
if err != nil { if err != nil {
ps.logger.Error("w.Write(ps.payload.ToSend())", zap.Error(err)) ps.logger.Error("w.Write(ps.PayloadManager.ToSend())", zap.Error(err))
} }
} }
} }

View File

@ -19,13 +19,26 @@ import (
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
) )
type PayloadManager interface {
Mount() error
Receive(data []byte) error
ToSend() []byte
Received() []byte
}
type PairingPayloadManagerConfig struct {
DB *multiaccounts.Database
KeystorePath, KeyUID, Password string
}
// PairingPayloadManager is responsible for the whole lifecycle of a PairingPayload
type PairingPayloadManager struct { type PairingPayloadManager struct {
pem *PayloadEncryptionManager pem *PayloadEncryptionManager
ppm *PairingPayloadMarshaller ppm *PairingPayloadMarshaller
ppr *PairingPayloadRepository ppr PayloadRepository
} }
func NewPairingPayloadManager(pk *ecdsa.PrivateKey, db *multiaccounts.Database) (*PairingPayloadManager, error) { func NewPairingPayloadManager(pk *ecdsa.PrivateKey, config *PairingPayloadManagerConfig) (*PairingPayloadManager, error) {
pem, err := NewPayloadEncryptionManager(pk) pem, err := NewPayloadEncryptionManager(pk)
if err != nil { if err != nil {
return nil, err return nil, err
@ -34,17 +47,55 @@ func NewPairingPayloadManager(pk *ecdsa.PrivateKey, db *multiaccounts.Database)
return &PairingPayloadManager{ return &PairingPayloadManager{
pem: pem, pem: pem,
ppm: NewPairingPayloadMarshaller(), ppm: NewPairingPayloadMarshaller(),
ppr: NewPairingPayloadRepository(db), ppr: NewPairingPayloadRepository(config),
}, nil }, nil
} }
// EncryptionPayload represents the plain text and encrypted text of a Server's payload data func (ppm *PairingPayloadManager) Mount() error {
err := ppm.ppr.LoadFromSource()
if err != nil {
return err
}
ppm.ppm.LoadPayload(ppm.ppr.GetPayload())
pb, err := ppm.ppm.MarshalToProtobuf()
if err != nil {
return err
}
return ppm.pem.Encrypt(pb)
}
func (ppm *PairingPayloadManager) Receive(data []byte) error {
err := ppm.pem.Decrypt(data)
if err != nil {
return err
}
err = ppm.ppm.UnmarshalProtobuf(ppm.pem.Received())
if err != nil {
return err
}
ppm.ppr.LoadPayload(ppm.ppm.GetPayload())
return ppm.ppr.StoreToSource()
}
func (ppm *PairingPayloadManager) ToSend() []byte {
return ppm.pem.ToSend()
}
func (ppm *PairingPayloadManager) Received() []byte {
return ppm.pem.Received()
}
// EncryptionPayload represents the plain text and encrypted text of payload data
type EncryptionPayload struct { type EncryptionPayload struct {
plain []byte plain []byte
encrypted []byte encrypted []byte
} }
// PayloadEncryptionManager is responsible for encrypting and decrypting a Server's payload data // PayloadEncryptionManager is responsible for encrypting and decrypting payload data
type PayloadEncryptionManager struct { type PayloadEncryptionManager struct {
aesKey []byte aesKey []byte
toSend *EncryptionPayload toSend *EncryptionPayload
@ -60,7 +111,7 @@ func NewPayloadEncryptionManager(pk *ecdsa.PrivateKey) (*PayloadEncryptionManage
return &PayloadEncryptionManager{ek, new(EncryptionPayload), new(EncryptionPayload)}, nil return &PayloadEncryptionManager{ek, new(EncryptionPayload), new(EncryptionPayload)}, nil
} }
func (pem *PayloadEncryptionManager) Mount(data []byte) error { func (pem *PayloadEncryptionManager) Encrypt(data []byte) error {
ep, err := common.Encrypt(data, pem.aesKey, rand.Reader) ep, err := common.Encrypt(data, pem.aesKey, rand.Reader)
if err != nil { if err != nil {
return err return err
@ -71,7 +122,7 @@ func (pem *PayloadEncryptionManager) Mount(data []byte) error {
return nil return nil
} }
func (pem *PayloadEncryptionManager) Receive(data []byte) error { func (pem *PayloadEncryptionManager) Decrypt(data []byte) error {
pd, err := common.Decrypt(data, pem.aesKey) pd, err := common.Decrypt(data, pem.aesKey)
if err != nil { if err != nil {
return err return err
@ -111,10 +162,14 @@ func NewPairingPayloadMarshaller() *PairingPayloadMarshaller {
return &PairingPayloadMarshaller{PairingPayload: new(PairingPayload)} return &PairingPayloadMarshaller{PairingPayload: new(PairingPayload)}
} }
func (ppm *PairingPayloadMarshaller) Load(payload *PairingPayload) { func (ppm *PairingPayloadMarshaller) LoadPayload(payload *PairingPayload) {
ppm.PairingPayload = payload ppm.PairingPayload = payload
} }
func (ppm *PairingPayloadMarshaller) GetPayload() *PairingPayload {
return ppm.PairingPayload
}
func (ppm *PairingPayloadMarshaller) MarshalToProtobuf() ([]byte, error) { func (ppm *PairingPayloadMarshaller) MarshalToProtobuf() ([]byte, error) {
return proto.Marshal(&protobuf.LocalPairingPayload{ return proto.Marshal(&protobuf.LocalPairingPayload{
Keys: ppm.accountKeysToProtobuf(), Keys: ppm.accountKeysToProtobuf(),
@ -228,40 +283,66 @@ func (ppm *PairingPayloadMarshaller) multiaccountFromProtobuf(pbMultiAccount *pr
} }
} }
type PayloadHandler interface {
LoadPayload(*PairingPayload)
GetPayload() *PairingPayload
}
type PayloadRepository interface {
PayloadHandler
LoadFromSource() error
StoreToSource() error
}
// PairingPayloadRepository is responsible for loading, parsing, validating and storing PairingServer payload data // PairingPayloadRepository is responsible for loading, parsing, validating and storing PairingServer payload data
type PairingPayloadRepository struct { type PairingPayloadRepository struct {
*PairingPayload *PairingPayload
multiaccountsDB *multiaccounts.Database multiaccountsDB *multiaccounts.Database
keystorePath, keyUID string
} }
func NewPairingPayloadRepository(db *multiaccounts.Database) *PairingPayloadRepository { func NewPairingPayloadRepository(config *PairingPayloadManagerConfig) *PairingPayloadRepository {
return &PairingPayloadRepository{ ppr := &PairingPayloadRepository{
PairingPayload: new(PairingPayload), PairingPayload: new(PairingPayload),
multiaccountsDB: db,
} }
if config == nil {
return ppr
}
ppr.multiaccountsDB = config.DB
ppr.keystorePath = config.KeystorePath
ppr.keyUID = config.KeyUID
ppr.password = config.Password
return ppr
} }
func (ppr *PairingPayloadRepository) Load(payload *PairingPayload) { func (ppr *PairingPayloadRepository) LoadPayload(payload *PairingPayload) {
ppr.PairingPayload = payload ppr.PairingPayload = payload
} }
func (ppr *PairingPayloadRepository) LoadFromSource(keystorePath, keyUID, password string) error { func (ppr *PairingPayloadRepository) GetPayload() *PairingPayload {
err := ppr.loadKeys(keystorePath) return ppr.PairingPayload
}
func (ppr *PairingPayloadRepository) LoadFromSource() error {
err := ppr.loadKeys(ppr.keystorePath)
if err != nil { if err != nil {
return err return err
} }
err = ppr.validateKeys(password) err = ppr.validateKeys(ppr.password)
if err != nil { if err != nil {
return err return err
} }
ppr.multiaccount, err = ppr.multiaccountsDB.GetAccount(keyUID) ppr.multiaccount, err = ppr.multiaccountsDB.GetAccount(ppr.keyUID)
if err != nil { if err != nil {
return err return err
} }
ppr.password = password
return nil return nil
} }
@ -305,13 +386,13 @@ func (ppr *PairingPayloadRepository) loadKeys(keyStorePath string) error {
return nil return nil
} }
func (ppr *PairingPayloadRepository) StoreToSource(keystorePath, password string) error { func (ppr *PairingPayloadRepository) StoreToSource() error {
err := ppr.validateKeys(password) err := ppr.validateKeys(ppr.password)
if err != nil { if err != nil {
return err return err
} }
err = ppr.storeKeys(keystorePath) err = ppr.storeKeys(ppr.keystorePath)
if err != nil { if err != nil {
return err return err
} }

View File

@ -39,10 +39,8 @@ type PayloadMarshallerSuite struct {
teardown func() teardown func()
db1 *multiaccounts.Database config1 *PairingPayloadManagerConfig
db2 *multiaccounts.Database config2 *PairingPayloadManagerConfig
keystore1 string
keystore2 string
} }
func setupTestDB(t *testing.T) (*multiaccounts.Database, func()) { func setupTestDB(t *testing.T) (*multiaccounts.Database, func()) {
@ -104,22 +102,32 @@ func getFiles(t *testing.T, keyStorePath string) map[string][]byte {
} }
func (pms *PayloadMarshallerSuite) SetupTest() { func (pms *PayloadMarshallerSuite) SetupTest() {
var db1td func() db1, db1td := setupTestDB(pms.T())
var db2td func() db2, db2td := setupTestDB(pms.T())
var kstd func() keystore1, keystore2, kstd := makeKeystores(pms.T())
pms.db1, db1td = setupTestDB(pms.T())
pms.db2, db2td = setupTestDB(pms.T())
pms.keystore1, pms.keystore2, kstd = makeKeystores(pms.T())
pms.teardown = func() { pms.teardown = func() {
db1td() db1td()
db2td() db2td()
kstd() kstd()
} }
initKeys(pms.T(), pms.keystore1) initKeys(pms.T(), keystore1)
err := pms.db1.SaveAccount(expected) err := db1.SaveAccount(expected)
pms.Require().NoError(err) pms.Require().NoError(err)
pms.config1 = &PairingPayloadManagerConfig{
DB: db1,
KeystorePath: keystore1,
KeyUID: keyUID,
Password: password,
}
pms.config2 = &PairingPayloadManagerConfig{
DB: db2,
KeystorePath: keystore2,
KeyUID: keyUID,
Password: password,
}
} }
func (pms *PayloadMarshallerSuite) TearDownTest() { func (pms *PayloadMarshallerSuite) TearDownTest() {
@ -128,43 +136,43 @@ func (pms *PayloadMarshallerSuite) TearDownTest() {
func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_LoadPayloads() { func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_LoadPayloads() {
// Make and LoadFromSource PairingPayloadRepository 1 // Make and LoadFromSource PairingPayloadRepository 1
pm := NewPairingPayloadRepository(pms.db1) ppr := NewPairingPayloadRepository(pms.config1)
err := pm.LoadFromSource(pms.keystore1, keyUID, password) err := ppr.LoadFromSource()
pms.Require().NoError(err) pms.Require().NoError(err)
// TEST PairingPayloadRepository 1 LoadFromSource() // TEST PairingPayloadRepository 1 LoadFromSource()
pms.Require().Len(pm.keys, 2) pms.Require().Len(ppr.keys, 2)
pms.Require().Len(pm.keys[utils.GetAccount1PKFile()], 489) pms.Require().Len(ppr.keys[utils.GetAccount1PKFile()], 489)
pms.Require().Len(pm.keys[utils.GetAccount2PKFile()], 489) pms.Require().Len(ppr.keys[utils.GetAccount2PKFile()], 489)
h1 := sha256.New() h1 := sha256.New()
h1.Write(pm.keys[utils.GetAccount1PKFile()]) h1.Write(ppr.keys[utils.GetAccount1PKFile()])
pms.Require().Exactly(account1Hash, h1.Sum(nil)) pms.Require().Exactly(account1Hash, h1.Sum(nil))
h2 := sha256.New() h2 := sha256.New()
h2.Write(pm.keys[utils.GetAccount2PKFile()]) h2.Write(ppr.keys[utils.GetAccount2PKFile()])
pms.Require().Exactly(account2Hash, h2.Sum(nil)) pms.Require().Exactly(account2Hash, h2.Sum(nil))
pms.Require().Exactly(expected.ColorHash, pm.multiaccount.ColorHash) pms.Require().Exactly(expected.ColorHash, ppr.multiaccount.ColorHash)
pms.Require().Exactly(expected.ColorID, pm.multiaccount.ColorID) pms.Require().Exactly(expected.ColorID, ppr.multiaccount.ColorID)
pms.Require().Exactly(expected.Identicon, pm.multiaccount.Identicon) pms.Require().Exactly(expected.Identicon, ppr.multiaccount.Identicon)
pms.Require().Exactly(expected.KeycardPairing, pm.multiaccount.KeycardPairing) pms.Require().Exactly(expected.KeycardPairing, ppr.multiaccount.KeycardPairing)
pms.Require().Exactly(expected.KeyUID, pm.multiaccount.KeyUID) pms.Require().Exactly(expected.KeyUID, ppr.multiaccount.KeyUID)
pms.Require().Exactly(expected.Name, pm.multiaccount.Name) pms.Require().Exactly(expected.Name, ppr.multiaccount.Name)
pms.Require().Exactly(expected.Timestamp, pm.multiaccount.Timestamp) pms.Require().Exactly(expected.Timestamp, ppr.multiaccount.Timestamp)
pms.Require().Len(pm.multiaccount.Images, 2) pms.Require().Len(ppr.multiaccount.Images, 2)
pms.Require().Equal(password, pm.password) pms.Require().Equal(password, ppr.password)
} }
func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_MarshalToProtobuf() { func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_MarshalToProtobuf() {
// Make and LoadFromSource PairingPayloadRepository 1 // Make and LoadFromSource PairingPayloadRepository 1
ppr := NewPairingPayloadRepository(pms.db1) ppr := NewPairingPayloadRepository(pms.config1)
err := ppr.LoadFromSource(pms.keystore1, keyUID, password) err := ppr.LoadFromSource()
pms.Require().NoError(err) pms.Require().NoError(err)
// Make and Load PairingPayloadMarshaller 1 // Make and Load PairingPayloadMarshaller 1
ppm := NewPairingPayloadMarshaller() ppm := NewPairingPayloadMarshaller()
ppm.Load(ppr.PairingPayload) ppm.LoadPayload(ppr.GetPayload())
// TEST PairingPayloadMarshaller 1 MarshalToProtobuf() // TEST PairingPayloadMarshaller 1 MarshalToProtobuf()
pb, err := ppm.MarshalToProtobuf() pb, err := ppm.MarshalToProtobuf()
@ -178,13 +186,13 @@ func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_MarshalToProtobuf() {
func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_UnmarshalProtobuf() { func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_UnmarshalProtobuf() {
// Make and LoadFromSource PairingPayloadRepository 1 // Make and LoadFromSource PairingPayloadRepository 1
ppr := NewPairingPayloadRepository(pms.db1) ppr := NewPairingPayloadRepository(pms.config1)
err := ppr.LoadFromSource(pms.keystore1, keyUID, password) err := ppr.LoadFromSource()
pms.Require().NoError(err) pms.Require().NoError(err)
// Make and Load PairingPayloadMarshaller 1 // Make and Load PairingPayloadMarshaller 1
ppm := NewPairingPayloadMarshaller() ppm := NewPairingPayloadMarshaller()
ppm.Load(ppr.PairingPayload) ppm.LoadPayload(ppr.GetPayload())
pb, err := ppm.MarshalToProtobuf() pb, err := ppm.MarshalToProtobuf()
pms.Require().NoError(err) pms.Require().NoError(err)
@ -226,13 +234,13 @@ func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_UnmarshalProtobuf() {
func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_StorePayloads() { func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_StorePayloads() {
// Make and LoadFromSource PairingPayloadRepository 1 // Make and LoadFromSource PairingPayloadRepository 1
ppr := NewPairingPayloadRepository(pms.db1) ppr := NewPairingPayloadRepository(pms.config1)
err := ppr.LoadFromSource(pms.keystore1, keyUID, password) err := ppr.LoadFromSource()
pms.Require().NoError(err) pms.Require().NoError(err)
// Make and Load PairingPayloadMarshaller 1 // Make and Load PairingPayloadMarshaller 1
ppm := NewPairingPayloadMarshaller() ppm := NewPairingPayloadMarshaller()
ppm.Load(ppr.PairingPayload) ppm.LoadPayload(ppr.PairingPayload)
pb, err := ppm.MarshalToProtobuf() pb, err := ppm.MarshalToProtobuf()
pms.Require().NoError(err) pms.Require().NoError(err)
@ -244,14 +252,14 @@ func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_StorePayloads() {
pms.Require().NoError(err) pms.Require().NoError(err)
// Make and Load PairingPayloadRepository 2 // Make and Load PairingPayloadRepository 2
ppr2 := NewPairingPayloadRepository(pms.db2) ppr2 := NewPairingPayloadRepository(pms.config2)
ppr2.Load(ppm2.PairingPayload) ppr2.LoadPayload(ppm2.PairingPayload)
err = ppr2.StoreToSource(pms.keystore2, password) err = ppr2.StoreToSource()
pms.Require().NoError(err) pms.Require().NoError(err)
// TEST PairingPayloadRepository 2 StoreToSource() // TEST PairingPayloadRepository 2 StoreToSource()
keys := getFiles(pms.T(), pms.keystore2) keys := getFiles(pms.T(), pms.config2.KeystorePath)
pms.Require().Len(keys, 2) pms.Require().Len(keys, 2)
pms.Require().Len(keys[utils.GetAccount1PKFile()], 489) pms.Require().Len(keys[utils.GetAccount1PKFile()], 489)
@ -265,7 +273,7 @@ func (pms *PayloadMarshallerSuite) TestPayloadMarshaller_StorePayloads() {
h2.Write(keys[utils.GetAccount2PKFile()]) h2.Write(keys[utils.GetAccount2PKFile()])
pms.Require().Exactly(account2Hash, h2.Sum(nil)) pms.Require().Exactly(account2Hash, h2.Sum(nil))
acc, err := pms.db2.GetAccount(keyUID) acc, err := pms.config2.DB.GetAccount(keyUID)
pms.Require().NoError(err) pms.Require().NoError(err)
pms.Require().Exactly(expected.ColorHash, acc.ColorHash) pms.Require().Exactly(expected.ColorHash, acc.ColorHash)

View File

@ -5,28 +5,30 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"github.com/status-im/status-go/multiaccounts"
) )
type PairingServer struct { type PairingServer struct {
Server Server
pk *ecdsa.PrivateKey pk *ecdsa.PrivateKey
mode Mode mode Mode
payload *PairingPayloadManager PayloadManager PayloadManager
} }
type Config struct { type Config struct {
// Connection fields
PK *ecdsa.PrivateKey PK *ecdsa.PrivateKey
Cert *tls.Certificate Cert *tls.Certificate
Hostname string Hostname string
Mode Mode Mode Mode
// Payload management fields
*PairingPayloadManagerConfig
} }
// NewPairingServer returns a *PairingServer init from the given *Config // NewPairingServer returns a *PairingServer init from the given *Config
func NewPairingServer(config *Config, db *multiaccounts.Database) (*PairingServer, error) { func NewPairingServer(config *Config) (*PairingServer, error) {
pm, err := NewPairingPayloadManager(config.PK, db) pm, err := NewPairingPayloadManager(config.PK, config.PairingPayloadManagerConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -35,9 +37,9 @@ func NewPairingServer(config *Config, db *multiaccounts.Database) (*PairingServe
config.Cert, config.Cert,
config.Hostname, config.Hostname,
), ),
pk: config.PK, pk: config.PK,
mode: config.Mode, mode: config.Mode,
payload: pm}, nil PayloadManager: pm}, nil
} }
// MakeConnectionParams generates a *ConnectionParams based on the Server's current state // MakeConnectionParams generates a *ConnectionParams based on the Server's current state
@ -68,10 +70,6 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
return NewConnectionParams(netIP, s.port, s.pk, s.cert.Leaf.NotBefore, s.mode), nil return NewConnectionParams(netIP, s.port, s.pk, s.cert.Leaf.NotBefore, s.mode), nil
} }
func (s *PairingServer) MountPayload(data []byte) error {
return s.payload.pem.Mount(data)
}
func (s *PairingServer) StartPairing() error { func (s *PairingServer) StartPairing() error {
switch s.mode { switch s.mode {
case Receiving: case Receiving:

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"crypto/rand"
"testing" "testing"
"time" "time"
@ -22,6 +21,11 @@ func (s *PairingServerSuite) SetupSuite() {
} }
func (s *PairingServerSuite) TestPairingServer_StartPairing() { func (s *PairingServerSuite) TestPairingServer_StartPairing() {
// Replace PairingServer.PayloadManager with a MockEncryptOnlyPayloadManager
pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralPK)
s.Require().NoError(err)
s.PS.PayloadManager = pm
modes := []Mode{ modes := []Mode{
Receiving, Receiving,
Sending, Sending,
@ -30,13 +34,8 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
for _, m := range modes { for _, m := range modes {
s.PS.mode = m s.PS.mode = m
// Random payload
data := make([]byte, 32)
_, err := rand.Read(data)
s.Require().NoError(err)
if m == Sending { if m == Sending {
err := s.PS.MountPayload(data) err := s.PS.PayloadManager.Mount()
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -60,8 +59,12 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
c, err := NewPairingClient(ccp, nil) c, err := NewPairingClient(ccp, nil)
s.Require().NoError(err) s.Require().NoError(err)
// Replace PairingClient.PayloadManager with a MockEncryptOnlyPayloadManager
c.PayloadManager, err = NewMockEncryptOnlyPayloadManager(s.EphemeralPK)
s.Require().NoError(err)
if m == Receiving { if m == Receiving {
err := c.MountPayload(data) err := c.PayloadManager.Mount()
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -70,18 +73,18 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
switch m { switch m {
case Receiving: case Receiving:
s.Require().Equal(data, s.PS.payload.pem.Received()) s.Require().Equal(c.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.toSend.plain, s.PS.PayloadManager.Received())
s.Require().Equal(s.PS.payload.pem.received.encrypted, c.payload.pem.toSend.encrypted) s.Require().Equal(s.PS.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.received.encrypted, c.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.toSend.encrypted)
s.Require().Nil(s.PS.payload.pem.ToSend()) s.Require().Nil(s.PS.PayloadManager.ToSend())
s.Require().Nil(c.payload.pem.Received()) s.Require().Nil(c.PayloadManager.Received())
case Sending: case Sending:
s.Require().Equal(c.payload.pem.Received(), data) s.Require().Equal(c.PayloadManager.Received(), s.PS.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.toSend.plain)
s.Require().Equal(c.payload.pem.received.encrypted, s.PS.payload.pem.toSend.encrypted) s.Require().Equal(c.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.received.encrypted, s.PS.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.toSend.encrypted)
s.Require().Nil(c.payload.pem.ToSend()) s.Require().Nil(c.PayloadManager.ToSend())
s.Require().Nil(s.PS.payload.pem.Received()) s.Require().Nil(s.PS.PayloadManager.Received())
} }
// Reset the server's PayloadEncryptionManager // Reset the server's PayloadEncryptionManager
s.PS.payload.pem.ResetPayload() s.PS.PayloadManager.(*MockEncryptOnlyPayloadManager).pem.ResetPayload()
} }
} }