Add tests for multi-device and refactor encryption service config (#1277)
This commit is contained in:
parent
aac706fe4c
commit
6112ca0289
|
@ -19,39 +19,50 @@ import (
|
||||||
|
|
||||||
var ErrSessionNotFound = errors.New("session not found")
|
var ErrSessionNotFound = errors.New("session not found")
|
||||||
|
|
||||||
// Max number of installations we keep synchronized.
|
|
||||||
const maxInstallations = 5
|
|
||||||
|
|
||||||
// If we have no bundles, we use a constant so that the message can reach any device.
|
// If we have no bundles, we use a constant so that the message can reach any device.
|
||||||
const noInstallationID = "none"
|
const noInstallationID = "none"
|
||||||
|
|
||||||
// How many consecutive messages can be skipped in the receiving chain.
|
|
||||||
const maxSkip = 1000
|
|
||||||
|
|
||||||
// Any message with seqNo <= currentSeq - maxKeep will be deleted.
|
|
||||||
const maxKeep = 3000
|
|
||||||
|
|
||||||
// How many keys do we store in total per session.
|
|
||||||
const maxMessageKeysPerSession = 2000
|
|
||||||
|
|
||||||
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol.
|
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol.
|
||||||
type EncryptionService struct {
|
type EncryptionService struct {
|
||||||
log log.Logger
|
log log.Logger
|
||||||
persistence PersistenceService
|
persistence PersistenceService
|
||||||
installationID string
|
config EncryptionServiceConfig
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EncryptionServiceConfig struct {
|
||||||
|
InstallationID string
|
||||||
|
// Max number of installations we keep synchronized.
|
||||||
|
MaxInstallations int
|
||||||
|
// How many consecutive messages can be skipped in the receiving chain.
|
||||||
|
MaxSkip int
|
||||||
|
// Any message with seqNo <= currentSeq - maxKeep will be deleted.
|
||||||
|
MaxKeep int
|
||||||
|
// How many keys do we store in total per session.
|
||||||
|
MaxMessageKeysPerSession int
|
||||||
|
}
|
||||||
|
|
||||||
type IdentityAndIDPair [2]string
|
type IdentityAndIDPair [2]string
|
||||||
|
|
||||||
|
// DefaultEncryptionServiceConfig returns the default values used by the encryption service
|
||||||
|
func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConfig {
|
||||||
|
return EncryptionServiceConfig{
|
||||||
|
MaxInstallations: 5,
|
||||||
|
MaxSkip: 1000,
|
||||||
|
MaxKeep: 3000,
|
||||||
|
MaxMessageKeysPerSession: 2000,
|
||||||
|
InstallationID: installationID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewEncryptionService creates a new EncryptionService instance.
|
// NewEncryptionService creates a new EncryptionService instance.
|
||||||
func NewEncryptionService(p PersistenceService, installationID string) *EncryptionService {
|
func NewEncryptionService(p PersistenceService, config EncryptionServiceConfig) *EncryptionService {
|
||||||
logger := log.New("package", "status-go/services/sshext.chat")
|
logger := log.New("package", "status-go/services/sshext.chat")
|
||||||
logger.Info("Initialized encryption service", "installationID", installationID)
|
logger.Info("Initialized encryption service", "installationID", config.InstallationID)
|
||||||
return &EncryptionService{
|
return &EncryptionService{
|
||||||
log: logger,
|
log: logger,
|
||||||
persistence: p,
|
persistence: p,
|
||||||
installationID: installationID,
|
config: config,
|
||||||
mutex: sync.Mutex{},
|
mutex: sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -65,16 +76,30 @@ func (s *EncryptionService) keyFromActiveX3DH(theirIdentityKey []byte, theirSign
|
||||||
return sharedKey, ephemeralPubKey, nil
|
return sharedKey, ephemeralPubKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *EncryptionService) getDRSession(id []byte) (dr.Session, error) {
|
||||||
|
sessionStorage := s.persistence.GetSessionStorage()
|
||||||
|
return dr.Load(
|
||||||
|
id,
|
||||||
|
sessionStorage,
|
||||||
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
|
dr.WithMaxSkip(s.config.MaxSkip),
|
||||||
|
dr.WithMaxKeep(s.config.MaxKeep),
|
||||||
|
dr.WithMaxMessageKeysPerSession(s.config.MaxMessageKeysPerSession),
|
||||||
|
dr.WithCrypto(crypto.EthereumCrypto{}),
|
||||||
|
)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// CreateBundle retrieves or creates an X3DH bundle given a private key
|
// CreateBundle retrieves or creates an X3DH bundle given a private key
|
||||||
func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, error) {
|
func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, error) {
|
||||||
ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey)
|
ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey)
|
||||||
|
|
||||||
installationIDs, err := s.persistence.GetActiveInstallations(maxInstallations-1, ourIdentityKeyC)
|
installationIDs, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations-1, ourIdentityKeyC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
installationIDs = append(installationIDs, s.installationID)
|
installationIDs = append(installationIDs, s.config.InstallationID)
|
||||||
|
|
||||||
bundleContainer, err := s.persistence.GetAnyPrivateBundle(ourIdentityKeyC, installationIDs)
|
bundleContainer, err := s.persistence.GetAnyPrivateBundle(ourIdentityKeyC, installationIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -98,7 +123,7 @@ func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle,
|
||||||
|
|
||||||
// needs transaction/mutex to avoid creating multiple bundles
|
// needs transaction/mutex to avoid creating multiple bundles
|
||||||
// although not a problem
|
// although not a problem
|
||||||
bundleContainer, err = NewBundleContainer(privateKey, s.installationID)
|
bundleContainer, err = NewBundleContainer(privateKey, s.config.InstallationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -182,7 +207,7 @@ func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey,
|
||||||
fromOurIdentity := identity != myIdentityStr
|
fromOurIdentity := identity != myIdentityStr
|
||||||
|
|
||||||
for installationID := range signedPreKeys {
|
for installationID := range signedPreKeys {
|
||||||
if installationID != s.installationID {
|
if installationID != s.config.InstallationID {
|
||||||
installationIDs = append(installationIDs, installationID)
|
installationIDs = append(installationIDs, installationID)
|
||||||
response = append(response, IdentityAndIDPair{identity, installationID})
|
response = append(response, IdentityAndIDPair{identity, installationID})
|
||||||
}
|
}
|
||||||
|
@ -204,7 +229,7 @@ func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, thei
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
msg := msgs[s.installationID]
|
msg := msgs[s.config.InstallationID]
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
msg = msgs[noInstallationID]
|
msg = msgs[noInstallationID]
|
||||||
}
|
}
|
||||||
|
@ -293,9 +318,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
|
||||||
keyPair,
|
keyPair,
|
||||||
s.persistence.GetSessionStorage(),
|
s.persistence.GetSessionStorage(),
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
dr.WithMaxSkip(maxSkip),
|
dr.WithMaxSkip(s.config.MaxSkip),
|
||||||
dr.WithMaxKeep(maxKeep),
|
dr.WithMaxKeep(s.config.MaxKeep),
|
||||||
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
dr.WithMaxMessageKeysPerSession(s.config.MaxMessageKeysPerSession),
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}))
|
dr.WithCrypto(crypto.EthereumCrypto{}))
|
||||||
} else {
|
} else {
|
||||||
session, err = dr.NewWithRemoteKey(
|
session, err = dr.NewWithRemoteKey(
|
||||||
|
@ -304,9 +329,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
|
||||||
keyPair.PubKey,
|
keyPair.PubKey,
|
||||||
s.persistence.GetSessionStorage(),
|
s.persistence.GetSessionStorage(),
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
dr.WithMaxSkip(maxSkip),
|
dr.WithMaxSkip(s.config.MaxSkip),
|
||||||
dr.WithMaxKeep(maxKeep),
|
dr.WithMaxKeep(s.config.MaxKeep),
|
||||||
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
dr.WithMaxMessageKeysPerSession(s.config.MaxMessageKeysPerSession),
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}))
|
dr.WithCrypto(crypto.EthereumCrypto{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -327,17 +352,9 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
|
||||||
PubKey: publicKey,
|
PubKey: publicKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionStorage := s.persistence.GetSessionStorage()
|
|
||||||
// Load session from store first
|
// Load session from store first
|
||||||
session, err = dr.Load(
|
session, err = s.getDRSession(drInfo.ID)
|
||||||
drInfo.ID,
|
|
||||||
sessionStorage,
|
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
|
||||||
dr.WithMaxSkip(maxSkip),
|
|
||||||
dr.WithMaxKeep(maxKeep),
|
|
||||||
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -379,16 +396,7 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
|
||||||
PubKey: publicKey,
|
PubKey: publicKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionStorage := s.persistence.GetSessionStorage()
|
session, err = s.getDRSession(drInfo.ID)
|
||||||
session, err = dr.Load(
|
|
||||||
drInfo.ID,
|
|
||||||
sessionStorage,
|
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
|
||||||
dr.WithMaxSkip(maxSkip),
|
|
||||||
dr.WithMaxKeep(maxKeep),
|
|
||||||
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -447,7 +455,7 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
|
||||||
|
|
||||||
theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey)
|
theirIdentityKeyC := ecrypto.CompressPubkey(theirIdentityKey)
|
||||||
|
|
||||||
installationIDs, err := s.persistence.GetActiveInstallations(maxInstallations, theirIdentityKeyC)
|
installationIDs, err := s.persistence.GetActiveInstallations(s.config.MaxInstallations, theirIdentityKeyC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -466,7 +474,7 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
|
||||||
response := make(map[string]*DirectMessageProtocol)
|
response := make(map[string]*DirectMessageProtocol)
|
||||||
|
|
||||||
for installationID, signedPreKeyContainer := range theirBundle.GetSignedPreKeys() {
|
for installationID, signedPreKeyContainer := range theirBundle.GetSignedPreKeys() {
|
||||||
if s.installationID == installationID {
|
if s.config.InstallationID == installationID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,7 +524,7 @@ func (s *EncryptionService) EncryptPayload(theirIdentityKey *ecdsa.PublicKey, my
|
||||||
Id: theirSignedPreKey,
|
Id: theirSignedPreKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
drInfo, err = s.persistence.GetAnyRatchetInfo(theirIdentityKeyC, installationID)
|
drInfo, err = s.persistence.GetRatchetInfo(theirSignedPreKey, theirIdentityKeyC, installationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package chat
|
package chat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"fmt"
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"os"
|
"os"
|
||||||
|
@ -8,114 +10,106 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
aliceUser = "alice"
|
||||||
|
bobUser = "bob"
|
||||||
|
)
|
||||||
|
|
||||||
func TestEncryptionServiceMultiDeviceSuite(t *testing.T) {
|
func TestEncryptionServiceMultiDeviceSuite(t *testing.T) {
|
||||||
suite.Run(t, new(EncryptionServiceMultiDeviceSuite))
|
suite.Run(t, new(EncryptionServiceMultiDeviceSuite))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type serviceAndKey struct {
|
||||||
|
encryptionServices []*EncryptionService
|
||||||
|
key *ecdsa.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
type EncryptionServiceMultiDeviceSuite struct {
|
type EncryptionServiceMultiDeviceSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
alice1 *EncryptionService
|
services map[string]*serviceAndKey
|
||||||
bob1 *EncryptionService
|
}
|
||||||
alice2 *EncryptionService
|
|
||||||
bob2 *EncryptionService
|
func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error {
|
||||||
alice3 *EncryptionService
|
key, err := crypto.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.services[user] = &serviceAndKey{
|
||||||
|
key: key,
|
||||||
|
encryptionServices: make([]*EncryptionService, n),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
installationID := fmt.Sprintf("%s%d", user, i+1)
|
||||||
|
dbPath := fmt.Sprintf("/tmp/%s.db", installationID)
|
||||||
|
|
||||||
|
os.Remove(dbPath)
|
||||||
|
|
||||||
|
persistence, err := NewSQLLitePersistence(dbPath, "key")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := DefaultEncryptionServiceConfig(installationID)
|
||||||
|
config.MaxInstallations = n - 1
|
||||||
|
|
||||||
|
s.services[user].encryptionServices[i] = NewEncryptionService(persistence, config)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *EncryptionServiceMultiDeviceSuite) SetupTest() {
|
func (s *EncryptionServiceMultiDeviceSuite) SetupTest() {
|
||||||
const (
|
s.services = make(map[string]*serviceAndKey)
|
||||||
aliceDBPath1 = "/tmp/alice1.db"
|
err := setupUser(aliceUser, s, 4)
|
||||||
aliceDBKey1 = "alice1"
|
s.Require().NoError(err)
|
||||||
aliceDBPath2 = "/tmp/alice2.db"
|
|
||||||
aliceDBKey2 = "alice2"
|
|
||||||
aliceDBPath3 = "/tmp/alice3.db"
|
|
||||||
aliceDBKey3 = "alice3"
|
|
||||||
bobDBPath1 = "/tmp/bob1.db"
|
|
||||||
bobDBKey1 = "bob1"
|
|
||||||
bobDBPath2 = "/tmp/bob2.db"
|
|
||||||
bobDBKey2 = "bob2"
|
|
||||||
)
|
|
||||||
|
|
||||||
os.Remove(aliceDBPath1)
|
err = setupUser(bobUser, s, 4)
|
||||||
os.Remove(bobDBPath1)
|
s.Require().NoError(err)
|
||||||
os.Remove(aliceDBPath2)
|
|
||||||
os.Remove(bobDBPath2)
|
|
||||||
os.Remove(aliceDBPath3)
|
|
||||||
|
|
||||||
alicePersistence1, err := NewSQLLitePersistence(aliceDBPath1, aliceDBKey1)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
alicePersistence2, err := NewSQLLitePersistence(aliceDBPath2, aliceDBKey2)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
alicePersistence3, err := NewSQLLitePersistence(aliceDBPath3, aliceDBKey3)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bobPersistence1, err := NewSQLLitePersistence(bobDBPath1, bobDBKey1)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bobPersistence2, err := NewSQLLitePersistence(bobDBPath2, bobDBKey2)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.alice1 = NewEncryptionService(alicePersistence1, "alice1")
|
|
||||||
s.bob1 = NewEncryptionService(bobPersistence1, "bob1")
|
|
||||||
|
|
||||||
s.alice2 = NewEncryptionService(alicePersistence2, "alice2")
|
|
||||||
s.bob2 = NewEncryptionService(bobPersistence2, "bob2")
|
|
||||||
|
|
||||||
s.alice3 = NewEncryptionService(alicePersistence3, "alice3")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
|
func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
|
||||||
aliceKey, err := crypto.GenerateKey()
|
aliceKey := s.services[aliceUser].key
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
alice2Bundle, err := s.alice2.CreateBundle(aliceKey)
|
alice2Bundle, err := s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
alice2Identity, err := ExtractIdentity(alice2Bundle)
|
alice2Identity, err := ExtractIdentity(alice2Bundle)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
alice3Bundle, err := s.alice3.CreateBundle(aliceKey)
|
alice3Bundle, err := s.services[aliceUser].encryptionServices[2].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
alice3Identity, err := ExtractIdentity(alice2Bundle)
|
alice3Identity, err := ExtractIdentity(alice2Bundle)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// Add alice2 bundle
|
// Add alice2 bundle
|
||||||
response, err := s.alice1.ProcessPublicBundle(aliceKey, alice2Bundle)
|
response, err := s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice2Bundle)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice2"}, response[0])
|
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice2"}, response[0])
|
||||||
|
|
||||||
// Add alice3 bundle
|
// Add alice3 bundle
|
||||||
response, err = s.alice1.ProcessPublicBundle(aliceKey, alice3Bundle)
|
response, err = s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice3Bundle)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Equal(IdentityAndIDPair{alice3Identity, "alice3"}, response[0])
|
s.Require().Equal(IdentityAndIDPair{alice3Identity, "alice3"}, response[0])
|
||||||
|
|
||||||
// No installation is enabled
|
// No installation is enabled
|
||||||
alice1MergedBundle1, err := s.alice1.CreateBundle(aliceKey)
|
alice1MergedBundle1, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
s.Require().Equal(1, len(alice1MergedBundle1.GetSignedPreKeys()))
|
s.Require().Equal(1, len(alice1MergedBundle1.GetSignedPreKeys()))
|
||||||
s.Require().NotNil(alice1MergedBundle1.GetSignedPreKeys()["alice1"])
|
s.Require().NotNil(alice1MergedBundle1.GetSignedPreKeys()["alice1"])
|
||||||
|
|
||||||
// We enable the installations
|
// We enable the installations
|
||||||
err = s.alice1.EnableInstallation(&aliceKey.PublicKey, "alice2")
|
err = s.services[aliceUser].encryptionServices[0].EnableInstallation(&aliceKey.PublicKey, "alice2")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
err = s.alice1.EnableInstallation(&aliceKey.PublicKey, "alice3")
|
err = s.services[aliceUser].encryptionServices[0].EnableInstallation(&aliceKey.PublicKey, "alice3")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
alice1MergedBundle2, err := s.alice1.CreateBundle(aliceKey)
|
alice1MergedBundle2, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// We get back a bundle with all the installations
|
// We get back a bundle with all the installations
|
||||||
|
@ -124,7 +118,7 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
|
||||||
s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice2"])
|
s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice2"])
|
||||||
s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice3"])
|
s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice3"])
|
||||||
|
|
||||||
response, err = s.alice1.ProcessPublicBundle(aliceKey, alice1MergedBundle2)
|
response, err = s.services[aliceUser].encryptionServices[0].ProcessPublicBundle(aliceKey, alice1MergedBundle2)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
sort.Slice(response, func(i, j int) bool {
|
sort.Slice(response, func(i, j int) bool {
|
||||||
return response[i][1] < response[j][1]
|
return response[i][1] < response[j][1]
|
||||||
|
@ -135,10 +129,10 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
|
||||||
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice3"}, response[1])
|
s.Require().Equal(IdentityAndIDPair{alice2Identity, "alice3"}, response[1])
|
||||||
|
|
||||||
// We disable the installations
|
// We disable the installations
|
||||||
err = s.alice1.DisableInstallation(&aliceKey.PublicKey, "alice2")
|
err = s.services[aliceUser].encryptionServices[0].DisableInstallation(&aliceKey.PublicKey, "alice2")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
alice1MergedBundle3, err := s.alice1.CreateBundle(aliceKey)
|
alice1MergedBundle3, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// We get back a bundle with all the installations
|
// We get back a bundle with all the installations
|
||||||
|
@ -152,25 +146,111 @@ func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundleOutOfOrder()
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// Alice1 creates a bundle
|
// Alice1 creates a bundle
|
||||||
alice1Bundle, err := s.alice1.CreateBundle(aliceKey)
|
alice1Bundle, err := s.services[aliceUser].encryptionServices[0].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// Alice2 Receives the bundle
|
// Alice2 Receives the bundle
|
||||||
_, err = s.alice2.ProcessPublicBundle(aliceKey, alice1Bundle)
|
_, err = s.services[aliceUser].encryptionServices[1].ProcessPublicBundle(aliceKey, alice1Bundle)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// Alice2 Creates a Bundle
|
// Alice2 Creates a Bundle
|
||||||
_, err = s.alice2.CreateBundle(aliceKey)
|
_, err = s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// We enable the installation
|
// We enable the installation
|
||||||
err = s.alice2.EnableInstallation(&aliceKey.PublicKey, "alice1")
|
err = s.services[aliceUser].encryptionServices[1].EnableInstallation(&aliceKey.PublicKey, "alice1")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// It should contain both bundles
|
// It should contain both bundles
|
||||||
alice2MergedBundle1, err := s.alice2.CreateBundle(aliceKey)
|
alice2MergedBundle1, err := s.services[aliceUser].encryptionServices[1].CreateBundle(aliceKey)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice1"])
|
s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice1"])
|
||||||
s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice2"])
|
s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice2"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func pairDevices(s *serviceAndKey, target int) error {
|
||||||
|
device := s.encryptionServices[target]
|
||||||
|
for i := 0; i < len(s.encryptionServices); i++ {
|
||||||
|
b, err := s.encryptionServices[i].CreateBundle(s.key)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = device.ProcessPublicBundle(s.key, b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = device.EnableInstallation(&s.key.PublicKey, s.encryptionServices[i].config.InstallationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() {
|
||||||
|
err := pairDevices(s.services[aliceUser], 0)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
alice1 := s.services[aliceUser].encryptionServices[0]
|
||||||
|
bob1 := s.services[bobUser].encryptionServices[0]
|
||||||
|
aliceKey := s.services[aliceUser].key
|
||||||
|
bobKey := s.services[bobUser].key
|
||||||
|
|
||||||
|
// Check bundle is ok
|
||||||
|
// No installation is enabled
|
||||||
|
aliceBundle, err := alice1.CreateBundle(aliceKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Check all installations are correctly working, and that the oldest device is not there
|
||||||
|
preKeys := aliceBundle.GetSignedPreKeys()
|
||||||
|
s.Require().Equal(3, len(preKeys))
|
||||||
|
s.Require().NotNil(preKeys["alice1"])
|
||||||
|
// alice2 being the oldest device is rotated out, as we reached the maximum
|
||||||
|
s.Require().Nil(preKeys["alice2"])
|
||||||
|
s.Require().NotNil(preKeys["alice3"])
|
||||||
|
s.Require().NotNil(preKeys["alice4"])
|
||||||
|
|
||||||
|
// We propagate this to bob
|
||||||
|
_, err = bob1.ProcessPublicBundle(bobKey, aliceBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Bob sends a message to alice
|
||||||
|
payload, err := bob1.EncryptPayload(&aliceKey.PublicKey, bobKey, []byte("test"))
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(3, len(payload))
|
||||||
|
s.Require().NotNil(payload["alice1"])
|
||||||
|
s.Require().NotNil(payload["alice3"])
|
||||||
|
s.Require().NotNil(payload["alice4"])
|
||||||
|
|
||||||
|
// We disable the last installation
|
||||||
|
err = s.services[aliceUser].encryptionServices[0].DisableInstallation(&aliceKey.PublicKey, "alice4")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We check the bundle is updated
|
||||||
|
aliceBundle, err = alice1.CreateBundle(aliceKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Check all installations are there
|
||||||
|
preKeys = aliceBundle.GetSignedPreKeys()
|
||||||
|
s.Require().Equal(3, len(preKeys))
|
||||||
|
s.Require().NotNil(preKeys["alice1"])
|
||||||
|
s.Require().NotNil(preKeys["alice2"])
|
||||||
|
s.Require().NotNil(preKeys["alice3"])
|
||||||
|
// alice4 is disabled at this point, alice2 is back in
|
||||||
|
s.Require().Nil(preKeys["alice4"])
|
||||||
|
|
||||||
|
// We propagate this to bob
|
||||||
|
_, err = bob1.ProcessPublicBundle(bobKey, aliceBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Bob sends a message to alice
|
||||||
|
payload, err = bob1.EncryptPayload(&aliceKey.PublicKey, bobKey, []byte("test"))
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(3, len(payload))
|
||||||
|
s.Require().NotNil(payload["alice1"])
|
||||||
|
s.Require().NotNil(payload["alice2"])
|
||||||
|
s.Require().NotNil(payload["alice3"])
|
||||||
|
}
|
||||||
|
|
|
@ -47,8 +47,8 @@ func (s *EncryptionServiceTestSuite) initDatabases() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.alice = NewEncryptionService(alicePersistence, aliceInstallationID)
|
s.alice = NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig(aliceInstallationID))
|
||||||
s.bob = NewEncryptionService(bobPersistence, bobInstallationID)
|
s.bob = NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig(bobInstallationID))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *EncryptionServiceTestSuite) SetupTest() {
|
func (s *EncryptionServiceTestSuite) SetupTest() {
|
||||||
|
@ -346,7 +346,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
|
||||||
|
|
||||||
// Bob sends a message
|
// Bob sends a message
|
||||||
|
|
||||||
for i := 0; i < maxSkip; i++ {
|
for i := 0; i < s.alice.config.MaxSkip; i++ {
|
||||||
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
}
|
}
|
||||||
|
@ -401,7 +401,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
|
||||||
|
|
||||||
// Bob sends a message
|
// Bob sends a message
|
||||||
|
|
||||||
for i := 0; i < maxSkip+1; i++ {
|
for i := 0; i < s.alice.config.MaxSkip+1; i++ {
|
||||||
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
}
|
}
|
||||||
|
@ -442,7 +442,7 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
|
||||||
|
|
||||||
// We create just enough messages so that the first key should be deleted
|
// We create just enough messages so that the first key should be deleted
|
||||||
|
|
||||||
nMessages := maxMessageKeysPerSession + maxMessageKeysPerSession/maxSkip + 2
|
nMessages := s.alice.config.MaxMessageKeysPerSession + s.alice.config.MaxMessageKeysPerSession/s.alice.config.MaxSkip + 2
|
||||||
messages := make([]map[string]*DirectMessageProtocol, nMessages)
|
messages := make([]map[string]*DirectMessageProtocol, nMessages)
|
||||||
for i := 0; i < nMessages; i++ {
|
for i := 0; i < nMessages; i++ {
|
||||||
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
@ -451,7 +451,7 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
|
||||||
messages[i] = m
|
messages[i] = m
|
||||||
|
|
||||||
// We decrypt some messages otherwise we hit maxSkip limit
|
// We decrypt some messages otherwise we hit maxSkip limit
|
||||||
if i%maxSkip == 0 {
|
if i%s.alice.config.MaxSkip == 0 {
|
||||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
}
|
}
|
||||||
|
@ -499,8 +499,8 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
// We decrypt all messages but 1 & 2
|
// We decrypt all messages but 1 & 2
|
||||||
messages := make([]map[string]*DirectMessageProtocol, maxKeep)
|
messages := make([]map[string]*DirectMessageProtocol, s.alice.config.MaxKeep)
|
||||||
for i := 0; i < maxKeep; i++ {
|
for i := 0; i < s.alice.config.MaxKeep; i++ {
|
||||||
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
messages[i] = m
|
messages[i] = m
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
|
@ -50,7 +50,7 @@ type PersistenceService interface {
|
||||||
RatchetInfoConfirmed([]byte, []byte, string) error
|
RatchetInfoConfirmed([]byte, []byte, string) error
|
||||||
|
|
||||||
// GetActiveInstallations returns the active installations for a given identity.
|
// GetActiveInstallations returns the active installations for a given identity.
|
||||||
GetActiveInstallations(maxInstallations uint, identity []byte) ([]string, error)
|
GetActiveInstallations(maxInstallations int, identity []byte) ([]string, error)
|
||||||
// AddInstallations adds the installations for a given identity.
|
// AddInstallations adds the installations for a given identity.
|
||||||
AddInstallations(identity []byte, timestamp int64, installationIDs []string, enabled bool) error
|
AddInstallations(identity []byte, timestamp int64, installationIDs []string, enabled bool) error
|
||||||
// EnableInstallation enables the installation.
|
// EnableInstallation enables the installation.
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (p *ProtocolService) addBundleAndMarshal(myIdentityKey *ecdsa.PrivateKey, m
|
||||||
func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte) ([]byte, error) {
|
func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte) ([]byte, error) {
|
||||||
// Build message not encrypted
|
// Build message not encrypted
|
||||||
protocolMessage := &ProtocolMessage{
|
protocolMessage := &ProtocolMessage{
|
||||||
InstallationId: p.encryption.installationID,
|
InstallationId: p.encryption.config.InstallationID,
|
||||||
PublicMessage: payload,
|
PublicMessage: payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, pa
|
||||||
|
|
||||||
// Build message
|
// Build message
|
||||||
protocolMessage := &ProtocolMessage{
|
protocolMessage := &ProtocolMessage{
|
||||||
InstallationId: p.encryption.installationID,
|
InstallationId: p.encryption.config.InstallationID,
|
||||||
DirectMessage: encryptionResponse,
|
DirectMessage: encryptionResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ func (p *ProtocolService) BuildPairingMessage(myIdentityKey *ecdsa.PrivateKey, p
|
||||||
|
|
||||||
// Build message
|
// Build message
|
||||||
protocolMessage := &ProtocolMessage{
|
protocolMessage := &ProtocolMessage{
|
||||||
InstallationId: p.encryption.installationID,
|
InstallationId: p.encryption.config.InstallationID,
|
||||||
DirectMessage: encryptionResponse,
|
DirectMessage: encryptionResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,8 +38,8 @@ func (s *ProtocolServiceTestSuite) SetupTest() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.alice = NewProtocolService(NewEncryptionService(alicePersistence, "1"))
|
s.alice = NewProtocolService(NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig("1")))
|
||||||
s.bob = NewProtocolService(NewEncryptionService(bobPersistence, "2"))
|
s.bob = NewProtocolService(NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig("2")))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
|
func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
|
||||||
|
|
|
@ -709,7 +709,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveInstallations returns the active installations for a given identity
|
// GetActiveInstallations returns the active installations for a given identity
|
||||||
func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations uint, identity []byte) ([]string, error) {
|
func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]string, error) {
|
||||||
stmt, err := s.db.Prepare("SELECT installation_id FROM installations WHERE enabled = 1 AND identity = ? ORDER BY timestamp DESC LIMIT ?")
|
stmt, err := s.db.Prepare("SELECT installation_id FROM installations WHERE enabled = 1 AND identity = ? ORDER BY timestamp DESC LIMIT ?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -102,7 +102,7 @@ func (s *Service) InitProtocol(address string, password string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.protocol = chat.NewProtocolService(chat.NewEncryptionService(persistence, s.installationID))
|
s.protocol = chat.NewProtocolService(chat.NewEncryptionService(persistence, chat.DefaultEncryptionServiceConfig(s.installationID)))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue