Add versioning & tests, migrate db files (#1293)
We are preparing for the release of this to general public, so a few things have been added: 1) Add versioning for bundles, and make refresh interval configurable 2) Move files to installationID so no metadata is leaked 3) Re-key using user password db
This commit is contained in:
parent
e60dbe3c1b
commit
38bb4d8ef3
|
@ -40,6 +40,8 @@ type EncryptionServiceConfig struct {
|
|||
MaxKeep int
|
||||
// How many keys do we store in total per session.
|
||||
MaxMessageKeysPerSession int
|
||||
// How long before we refresh the interval in milliseconds
|
||||
BundleRefreshInterval int64
|
||||
}
|
||||
|
||||
type IdentityAndIDPair [2]string
|
||||
|
@ -51,6 +53,7 @@ func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConf
|
|||
MaxSkip: 1000,
|
||||
MaxKeep: 3000,
|
||||
MaxMessageKeysPerSession: 2000,
|
||||
BundleRefreshInterval: 14 * 24 * 60 * 60 * 1000,
|
||||
InstallationID: installationID,
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +110,7 @@ func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle,
|
|||
}
|
||||
|
||||
// If the bundle has expired we create a new one
|
||||
if bundleContainer != nil && bundleContainer.GetBundle().Timestamp < time.Now().AddDate(0, 0, -14).UnixNano() {
|
||||
if bundleContainer != nil && bundleContainer.GetBundle().Timestamp < time.Now().Add(-1*time.Duration(s.config.BundleRefreshInterval)*time.Millisecond).UnixNano() {
|
||||
// Mark sessions has expired
|
||||
if err := s.persistence.MarkBundleExpired(bundleContainer.GetBundle().GetIdentity()); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -25,16 +26,33 @@ func TestEncryptionServiceTestSuite(t *testing.T) {
|
|||
|
||||
type EncryptionServiceTestSuite struct {
|
||||
suite.Suite
|
||||
alice *EncryptionService
|
||||
bob *EncryptionService
|
||||
alice *EncryptionService
|
||||
bob *EncryptionService
|
||||
aliceDBPath string
|
||||
bobDBPath string
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) initDatabases() {
|
||||
func (s *EncryptionServiceTestSuite) initDatabases(baseConfig *EncryptionServiceConfig) {
|
||||
|
||||
aliceDBFile, err := ioutil.TempFile(os.TempDir(), "alice")
|
||||
s.Require().NoError(err)
|
||||
aliceDBPath := aliceDBFile.Name()
|
||||
|
||||
bobDBFile, err := ioutil.TempFile(os.TempDir(), "bob")
|
||||
s.Require().NoError(err)
|
||||
bobDBPath := bobDBFile.Name()
|
||||
|
||||
s.aliceDBPath = aliceDBPath
|
||||
s.bobDBPath = bobDBPath
|
||||
|
||||
if baseConfig == nil {
|
||||
config := DefaultEncryptionServiceConfig(aliceInstallationID)
|
||||
baseConfig = &config
|
||||
}
|
||||
|
||||
const (
|
||||
aliceDBPath = "/tmp/alice.db"
|
||||
aliceDBKey = "alice"
|
||||
bobDBPath = "/tmp/bob.db"
|
||||
bobDBKey = "bob"
|
||||
aliceDBKey = "alice"
|
||||
bobDBKey = "bob"
|
||||
)
|
||||
|
||||
alicePersistence, err := NewSQLLitePersistence(aliceDBPath, aliceDBKey)
|
||||
|
@ -47,14 +65,20 @@ func (s *EncryptionServiceTestSuite) initDatabases() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
s.alice = NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig(aliceInstallationID))
|
||||
s.bob = NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig(bobInstallationID))
|
||||
baseConfig.InstallationID = aliceInstallationID
|
||||
s.alice = NewEncryptionService(alicePersistence, *baseConfig)
|
||||
|
||||
baseConfig.InstallationID = bobInstallationID
|
||||
s.bob = NewEncryptionService(bobPersistence, *baseConfig)
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) SetupTest() {
|
||||
os.Remove("/tmp/alice.db")
|
||||
os.Remove("/tmp/bob.db")
|
||||
s.initDatabases()
|
||||
s.initDatabases(nil)
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) TearDownTest() {
|
||||
os.Remove(s.aliceDBPath)
|
||||
os.Remove(s.bobDBPath)
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) TestCreateBundle() {
|
||||
|
@ -749,6 +773,12 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() {
|
|||
// A new bundle has been received
|
||||
func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
|
||||
|
||||
config := DefaultEncryptionServiceConfig("none")
|
||||
// Set up refresh interval to "always"
|
||||
config.BundleRefreshInterval = 1000
|
||||
|
||||
s.initDatabases(&config)
|
||||
|
||||
bobKey, err := crypto.GenerateKey()
|
||||
s.Require().NoError(err)
|
||||
|
||||
|
@ -756,23 +786,20 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Create bundles
|
||||
bobBundle1, err := NewBundleContainer(bobKey, bobInstallationID)
|
||||
bobBundle1, err := s.bob.CreateBundle(bobKey)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(uint32(1), bobBundle1.GetSignedPreKeys()[bobInstallationID].GetVersion())
|
||||
|
||||
err = SignBundle(bobKey, bobBundle1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID)
|
||||
s.Require().NoError(err)
|
||||
// We set the version
|
||||
|
||||
bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1
|
||||
|
||||
err = SignBundle(bobKey, bobBundle2)
|
||||
// Sleep the required time so that bundle is refreshed
|
||||
time.Sleep(time.Duration(config.BundleRefreshInterval) * time.Millisecond)
|
||||
|
||||
// Create bundles
|
||||
bobBundle2, err := s.bob.CreateBundle(bobKey)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(uint32(2), bobBundle2.GetSignedPreKeys()[bobInstallationID].GetVersion())
|
||||
|
||||
// We add the first bob bundle
|
||||
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle1.GetBundle())
|
||||
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice sends a message
|
||||
|
@ -786,10 +813,10 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
|
|||
|
||||
x3dhHeader1 := installationResponse1.GetX3DHHeader()
|
||||
s.NotNil(x3dhHeader1)
|
||||
s.Equal(bobBundle1.GetBundle().GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId())
|
||||
s.Equal(bobBundle1.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId())
|
||||
|
||||
// We add the second bob bundle
|
||||
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle2.GetBundle())
|
||||
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice sends a message
|
||||
|
@ -803,6 +830,6 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
|
|||
|
||||
x3dhHeader2 := installationResponse2.GetX3DHHeader()
|
||||
s.NotNil(x3dhHeader2)
|
||||
s.Equal(bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
|
||||
s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
|
||||
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package chat
|
|||
import (
|
||||
"crypto/ecdsa"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
|
@ -51,6 +53,64 @@ func NewSQLLitePersistence(path string, key string) (*SQLLitePersistence, error)
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func MigrateDBFile(oldPath string, newPath string, key string) error {
|
||||
_, err := os.Stat(oldPath)
|
||||
|
||||
// No files, nothing to do
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Any other error, throws
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(oldPath, newPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Migrate dev/nightly builds which used ON as a key for debugging
|
||||
db, err := openDB(newPath, "ON")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyString := fmt.Sprintf("PRAGMA rekey=%s", key)
|
||||
|
||||
if _, err = db.Exec(keyString); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func openDB(path string, key string) (*sql.DB, error) {
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyString := fmt.Sprintf("PRAGMA key=%s", key)
|
||||
|
||||
// Disable concurrent access as not supported by the driver
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
if _, err = db.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = db.Exec(keyString); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = db.Exec("PRAGMA cypher_page_size=4096"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// NewSQLLiteKeysStorage creates a new SQLLiteKeysStorage instance associated with the specified database
|
||||
func NewSQLLiteKeysStorage(db *sql.DB) *SQLLiteKeysStorage {
|
||||
return &SQLLiteKeysStorage{
|
||||
|
@ -77,26 +137,11 @@ func (s *SQLLitePersistence) GetSessionStorage() dr.SessionStorage {
|
|||
|
||||
// Open opens a file at the specified path
|
||||
func (s *SQLLitePersistence) Open(path string, key string) error {
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
db, err := openDB(path, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Disable concurrent access as not supported by the driver
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
if _, err = db.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = db.Exec("PRAGMA key=ON"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = db.Exec("PRAGMA cypher_page_size=4096"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.db = db
|
||||
|
||||
return s.setup()
|
||||
|
@ -111,7 +156,11 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error {
|
|||
|
||||
for installationID, signedPreKey := range bc.GetBundle().GetSignedPreKeys() {
|
||||
var version uint32
|
||||
stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1")
|
||||
stmt, err := tx.Prepare(`SELECT version
|
||||
FROM bundles
|
||||
WHERE installation_id = ? AND identity = ?
|
||||
ORDER BY version DESC
|
||||
LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -123,7 +172,8 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error {
|
|||
return err
|
||||
}
|
||||
|
||||
stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)")
|
||||
stmt, err = tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp)
|
||||
VALUES(?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -162,7 +212,8 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
|
|||
for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
|
||||
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
|
||||
version := signedPreKeyContainer.GetVersion()
|
||||
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)")
|
||||
insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp)
|
||||
VALUES( ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -180,7 +231,9 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
|
|||
return err
|
||||
}
|
||||
// Mark old bundles as expired
|
||||
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?")
|
||||
updateStmt, err := tx.Prepare(`UPDATE bundles
|
||||
SET expired = 1
|
||||
WHERE identity = ? AND installation_id = ? AND version < ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -205,7 +258,9 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
|
|||
func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installationIDs []string) (*BundleContainer, error) {
|
||||
|
||||
/* #nosec */
|
||||
statement := "SELECT identity, private_key, signed_pre_key, installation_id, timestamp FROM bundles WHERE expired = 0 AND identity = ? AND installation_id IN (?" + strings.Repeat(",?", len(installationIDs)-1) + ")"
|
||||
statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version
|
||||
FROM bundles
|
||||
WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + ")"
|
||||
stmt, err := s.db.Prepare(statement)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -215,6 +270,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
|
|||
var timestamp int64
|
||||
var identity []byte
|
||||
var privateKey []byte
|
||||
var version uint32
|
||||
|
||||
args := make([]interface{}, len(installationIDs)+1)
|
||||
args[0] = myIdentityKey
|
||||
|
@ -249,6 +305,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
|
|||
&signedPreKey,
|
||||
&installationID,
|
||||
×tamp,
|
||||
&version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -258,7 +315,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
|
|||
bundle.Timestamp = timestamp
|
||||
}
|
||||
|
||||
bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
|
||||
bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey, Version: version}
|
||||
bundle.Identity = identity
|
||||
}
|
||||
|
||||
|
@ -273,7 +330,9 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
|
|||
|
||||
// GetPrivateKeyBundle retrieves a private key for a bundle from the database
|
||||
func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) {
|
||||
stmt, err := s.db.Prepare("SELECT private_key FROM bundles WHERE expired = 0 AND signed_pre_key = ? LIMIT 1")
|
||||
stmt, err := s.db.Prepare(`SELECT private_key
|
||||
FROM bundles
|
||||
WHERE expired = 0 AND signed_pre_key = ? LIMIT 1`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -294,7 +353,9 @@ func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error
|
|||
|
||||
// MarkBundleExpired expires any private bundle for a given identity
|
||||
func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
|
||||
stmt, err := s.db.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND private_key IS NOT NULL")
|
||||
stmt, err := s.db.Prepare(`UPDATE bundles
|
||||
SET expired = 1
|
||||
WHERE identity = ? AND private_key IS NOT NULL`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -315,7 +376,10 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install
|
|||
identity := crypto.CompressPubkey(publicKey)
|
||||
|
||||
/* #nosec */
|
||||
statement := "SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? AND installation_id IN (?" + strings.Repeat(",?", len(installationIDs)-1) + ") ORDER BY version DESC"
|
||||
statement := `SELECT signed_pre_key,installation_id, version
|
||||
FROM bundles
|
||||
WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `)
|
||||
ORDER BY version DESC`
|
||||
stmt, err := s.db.Prepare(statement)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -373,7 +437,8 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install
|
|||
|
||||
// AddRatchetInfo persists the specified ratchet info into the database
|
||||
func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error {
|
||||
stmt, err := s.db.Prepare("INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) VALUES(?, ?, ?, ?, ?)")
|
||||
stmt, err := s.db.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id)
|
||||
VALUES(?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -392,7 +457,10 @@ func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleI
|
|||
|
||||
// GetRatchetInfo retrieves the existing RatchetInfo for a specified bundle ID and interlocutor public key from the database
|
||||
func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) {
|
||||
stmt, err := s.db.Prepare("SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? LIMIT 1")
|
||||
stmt, err := s.db.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id
|
||||
FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
|
||||
WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ?
|
||||
LIMIT 1`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -423,7 +491,10 @@ func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byt
|
|||
|
||||
// GetAnyRatchetInfo retrieves any existing RatchetInfo for a specified interlocutor public key from the database
|
||||
func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) {
|
||||
stmt, err := s.db.Prepare("SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? LIMIT 1")
|
||||
stmt, err := s.db.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key
|
||||
FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
|
||||
WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ?
|
||||
LIMIT 1`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -455,7 +526,9 @@ func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID s
|
|||
// RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo
|
||||
// associated with the specified bundle ID and interlocutor identity public key
|
||||
func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error {
|
||||
stmt, err := s.db.Prepare("UPDATE ratchet_info_v2 SET ephemeral_key = NULL WHERE identity = ? AND bundle_id = ? AND installation_id = ?")
|
||||
stmt, err := s.db.Prepare(`UPDATE ratchet_info_v2
|
||||
SET ephemeral_key = NULL
|
||||
WHERE identity = ? AND bundle_id = ? AND installation_id = ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -474,7 +547,10 @@ func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity
|
|||
func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) {
|
||||
var keyBytes []byte
|
||||
var key [32]byte
|
||||
stmt, err := s.db.Prepare("SELECT message_key FROM keys WHERE public_key = ? AND msg_num = ? LIMIT 1")
|
||||
stmt, err := s.db.Prepare(`SELECT message_key
|
||||
FROM keys
|
||||
WHERE public_key = ? AND msg_num = ?
|
||||
LIMIT 1`)
|
||||
|
||||
if err != nil {
|
||||
return key, false, err
|
||||
|
@ -495,7 +571,8 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro
|
|||
|
||||
// Put stores a key with the specified public key, message number and message key
|
||||
func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
|
||||
stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)")
|
||||
stmt, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num)
|
||||
VALUES(?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -514,7 +591,8 @@ func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, m
|
|||
|
||||
// DeleteOldMks caps remove any key < seq_num, included
|
||||
func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
|
||||
stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?")
|
||||
stmt, err := s.db.Prepare(`DELETE FROM keys
|
||||
WHERE session_id = ? AND seq_num <= ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -530,7 +608,8 @@ func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) er
|
|||
|
||||
// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion
|
||||
func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
|
||||
stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)")
|
||||
stmt, err := s.db.Prepare(`DELETE FROM keys
|
||||
WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -548,7 +627,8 @@ func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int
|
|||
|
||||
// DeleteMk deletes the key with the specified public key and message key
|
||||
func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
|
||||
stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ? AND msg_num = ?")
|
||||
stmt, err := s.db.Prepare(`DELETE FROM keys
|
||||
WHERE public_key = ? AND msg_num = ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -564,7 +644,9 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
|
|||
|
||||
// Count returns the count of keys with the specified public key
|
||||
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
|
||||
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
|
||||
stmt, err := s.db.Prepare(`SELECT COUNT(1)
|
||||
FROM keys
|
||||
WHERE public_key = ?`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -581,7 +663,8 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
|
|||
|
||||
// CountAll returns the count of keys with the specified public key
|
||||
func (s *SQLLiteKeysStorage) CountAll() (uint, error) {
|
||||
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys")
|
||||
stmt, err := s.db.Prepare(`SELECT COUNT(1)
|
||||
FROM keys`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -619,7 +702,8 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
|
|||
recvChainKey := state.RecvCh.CK[:]
|
||||
recvChainN := state.RecvCh.N
|
||||
|
||||
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
|
||||
stmt, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count)
|
||||
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -645,7 +729,9 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
|
|||
|
||||
// Load retrieves the double ratchet state for a given ID
|
||||
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
||||
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?")
|
||||
stmt, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count
|
||||
FROM sessions
|
||||
WHERE id = ?`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -710,7 +796,11 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
|||
|
||||
// GetActiveInstallations returns the active installations for a given identity
|
||||
func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]string, error) {
|
||||
stmt, err := s.db.Prepare("SELECT installation_id FROM installations WHERE enabled = 1 AND identity = ? ORDER BY timestamp DESC LIMIT ?")
|
||||
stmt, err := s.db.Prepare(`SELECT installation_id
|
||||
FROM installations
|
||||
WHERE enabled = 1 AND identity = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -744,7 +834,10 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
|
|||
}
|
||||
|
||||
for _, installationID := range installationIDs {
|
||||
stmt, err := tx.Prepare("SELECT enabled FROM installations WHERE identity = ? AND installation_id = ? LIMIT 1")
|
||||
stmt, err := tx.Prepare(`SELECT enabled
|
||||
FROM installations
|
||||
WHERE identity = ? AND installation_id = ?
|
||||
LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -759,7 +852,9 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
|
|||
|
||||
// We update timestamp if present without changing enabled
|
||||
if err != sql.ErrNoRows {
|
||||
stmt, err = tx.Prepare("UPDATE installations SET timestamp = ?, enabled = ? WHERE identity = ? AND installation_id = ?")
|
||||
stmt, err = tx.Prepare(`UPDATE installations
|
||||
SET timestamp = ?, enabled = ?
|
||||
WHERE identity = ? AND installation_id = ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -776,7 +871,8 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
|
|||
defer stmt.Close()
|
||||
|
||||
} else {
|
||||
stmt, err = tx.Prepare("INSERT INTO installations(identity, installation_id, timestamp, enabled) VALUES (?, ?, ?, ?)")
|
||||
stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled)
|
||||
VALUES (?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -806,7 +902,9 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
|
|||
|
||||
// EnableInstallation enables the installation
|
||||
func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error {
|
||||
stmt, err := s.db.Prepare("UPDATE installations SET enabled = 1 WHERE identity = ? AND installation_id = ?")
|
||||
stmt, err := s.db.Prepare(`UPDATE installations
|
||||
SET enabled = 1
|
||||
WHERE identity = ? AND installation_id = ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -819,7 +917,9 @@ func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID
|
|||
// DisableInstallation disable the installation
|
||||
func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error {
|
||||
|
||||
stmt, err := s.db.Prepare("UPDATE installations SET enabled = 0 WHERE identity = ? AND installation_id = ?")
|
||||
stmt, err := s.db.Prepare(`UPDATE installations
|
||||
SET enabled = 0
|
||||
WHERE identity = ? AND installation_id = ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
|
@ -74,7 +73,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
|
|||
anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []string{installationID})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(anyPrivateBundle)
|
||||
s.True(proto.Equal(bundle.GetBundle(), anyPrivateBundle.GetBundle()), "It returns the same bundle")
|
||||
s.Equal(bundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, anyPrivateBundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, "It returns the same bundle")
|
||||
}
|
||||
|
||||
func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
|
||||
|
|
|
@ -107,7 +107,14 @@ func (s *Service) InitProtocol(address string, password string) error {
|
|||
if err := os.MkdirAll(filepath.Clean(s.dataDir), os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
persistence, err := chat.NewSQLLitePersistence(filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address)), password)
|
||||
oldPath := filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address))
|
||||
newPath := filepath.Join(s.dataDir, fmt.Sprintf("%s.db", s.installationID))
|
||||
|
||||
if err := chat.MigrateDBFile(oldPath, newPath, password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
persistence, err := chat.NewSQLLitePersistence(newPath, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue