Add versioning & tests, migrate db files (#1293)

We are preparing for the release of this to general public, so a few
things have been added:

1) Add versioning for bundles, and make refresh interval configurable
2) Move files to installationID so no metadata is leaked
3) Re-key using user password db
This commit is contained in:
Andrea Maria Piana 2018-11-28 12:34:39 +01:00 committed by GitHub
parent e60dbe3c1b
commit 38bb4d8ef3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 211 additions and 75 deletions

View File

@ -40,6 +40,8 @@ type EncryptionServiceConfig struct {
MaxKeep int
// How many keys do we store in total per session.
MaxMessageKeysPerSession int
// How long before we refresh the interval in milliseconds
BundleRefreshInterval int64
}
type IdentityAndIDPair [2]string
@ -51,6 +53,7 @@ func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConf
MaxSkip: 1000,
MaxKeep: 3000,
MaxMessageKeysPerSession: 2000,
BundleRefreshInterval: 14 * 24 * 60 * 60 * 1000,
InstallationID: installationID,
}
}
@ -107,7 +110,7 @@ func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle,
}
// If the bundle has expired we create a new one
if bundleContainer != nil && bundleContainer.GetBundle().Timestamp < time.Now().AddDate(0, 0, -14).UnixNano() {
if bundleContainer != nil && bundleContainer.GetBundle().Timestamp < time.Now().Add(-1*time.Duration(s.config.BundleRefreshInterval)*time.Millisecond).UnixNano() {
// Mark sessions has expired
if err := s.persistence.MarkBundleExpired(bundleContainer.GetBundle().GetIdentity()); err != nil {
return nil, err

View File

@ -4,6 +4,7 @@ import (
"crypto/ecdsa"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"os"
"reflect"
@ -25,16 +26,33 @@ func TestEncryptionServiceTestSuite(t *testing.T) {
type EncryptionServiceTestSuite struct {
suite.Suite
alice *EncryptionService
bob *EncryptionService
alice *EncryptionService
bob *EncryptionService
aliceDBPath string
bobDBPath string
}
func (s *EncryptionServiceTestSuite) initDatabases() {
func (s *EncryptionServiceTestSuite) initDatabases(baseConfig *EncryptionServiceConfig) {
aliceDBFile, err := ioutil.TempFile(os.TempDir(), "alice")
s.Require().NoError(err)
aliceDBPath := aliceDBFile.Name()
bobDBFile, err := ioutil.TempFile(os.TempDir(), "bob")
s.Require().NoError(err)
bobDBPath := bobDBFile.Name()
s.aliceDBPath = aliceDBPath
s.bobDBPath = bobDBPath
if baseConfig == nil {
config := DefaultEncryptionServiceConfig(aliceInstallationID)
baseConfig = &config
}
const (
aliceDBPath = "/tmp/alice.db"
aliceDBKey = "alice"
bobDBPath = "/tmp/bob.db"
bobDBKey = "bob"
aliceDBKey = "alice"
bobDBKey = "bob"
)
alicePersistence, err := NewSQLLitePersistence(aliceDBPath, aliceDBKey)
@ -47,14 +65,20 @@ func (s *EncryptionServiceTestSuite) initDatabases() {
panic(err)
}
s.alice = NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig(aliceInstallationID))
s.bob = NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig(bobInstallationID))
baseConfig.InstallationID = aliceInstallationID
s.alice = NewEncryptionService(alicePersistence, *baseConfig)
baseConfig.InstallationID = bobInstallationID
s.bob = NewEncryptionService(bobPersistence, *baseConfig)
}
func (s *EncryptionServiceTestSuite) SetupTest() {
os.Remove("/tmp/alice.db")
os.Remove("/tmp/bob.db")
s.initDatabases()
s.initDatabases(nil)
}
func (s *EncryptionServiceTestSuite) TearDownTest() {
os.Remove(s.aliceDBPath)
os.Remove(s.bobDBPath)
}
func (s *EncryptionServiceTestSuite) TestCreateBundle() {
@ -749,6 +773,12 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() {
// A new bundle has been received
func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
config := DefaultEncryptionServiceConfig("none")
// Set up refresh interval to "always"
config.BundleRefreshInterval = 1000
s.initDatabases(&config)
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
@ -756,23 +786,20 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Require().NoError(err)
// Create bundles
bobBundle1, err := NewBundleContainer(bobKey, bobInstallationID)
bobBundle1, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
s.Require().Equal(uint32(1), bobBundle1.GetSignedPreKeys()[bobInstallationID].GetVersion())
err = SignBundle(bobKey, bobBundle1)
s.Require().NoError(err)
bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID)
s.Require().NoError(err)
// We set the version
bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1
err = SignBundle(bobKey, bobBundle2)
// Sleep the required time so that bundle is refreshed
time.Sleep(time.Duration(config.BundleRefreshInterval) * time.Millisecond)
// Create bundles
bobBundle2, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
s.Require().Equal(uint32(2), bobBundle2.GetSignedPreKeys()[bobInstallationID].GetVersion())
// We add the first bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle1.GetBundle())
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle1)
s.Require().NoError(err)
// Alice sends a message
@ -786,10 +813,10 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
x3dhHeader1 := installationResponse1.GetX3DHHeader()
s.NotNil(x3dhHeader1)
s.Equal(bobBundle1.GetBundle().GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId())
s.Equal(bobBundle1.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId())
// We add the second bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle2.GetBundle())
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle2)
s.Require().NoError(err)
// Alice sends a message
@ -803,6 +830,6 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
x3dhHeader2 := installationResponse2.GetX3DHHeader()
s.NotNil(x3dhHeader2)
s.Equal(bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
}

View File

@ -3,6 +3,8 @@ package chat
import (
"crypto/ecdsa"
"database/sql"
"fmt"
"os"
"strings"
"github.com/ethereum/go-ethereum/crypto"
@ -51,6 +53,64 @@ func NewSQLLitePersistence(path string, key string) (*SQLLitePersistence, error)
return s, nil
}
func MigrateDBFile(oldPath string, newPath string, key string) error {
_, err := os.Stat(oldPath)
// No files, nothing to do
if os.IsNotExist(err) {
return nil
}
// Any other error, throws
if err != nil {
return err
}
if err := os.Rename(oldPath, newPath); err != nil {
return err
}
// Migrate dev/nightly builds which used ON as a key for debugging
db, err := openDB(newPath, "ON")
if err != nil {
return err
}
keyString := fmt.Sprintf("PRAGMA rekey=%s", key)
if _, err = db.Exec(keyString); err != nil {
return err
}
return nil
}
func openDB(path string, key string) (*sql.DB, error) {
db, err := sql.Open("sqlite3", path)
if err != nil {
return nil, err
}
keyString := fmt.Sprintf("PRAGMA key=%s", key)
// Disable concurrent access as not supported by the driver
db.SetMaxOpenConns(1)
if _, err = db.Exec("PRAGMA foreign_keys=ON"); err != nil {
return nil, err
}
if _, err = db.Exec(keyString); err != nil {
return nil, err
}
if _, err = db.Exec("PRAGMA cypher_page_size=4096"); err != nil {
return nil, err
}
return db, nil
}
// NewSQLLiteKeysStorage creates a new SQLLiteKeysStorage instance associated with the specified database
func NewSQLLiteKeysStorage(db *sql.DB) *SQLLiteKeysStorage {
return &SQLLiteKeysStorage{
@ -77,26 +137,11 @@ func (s *SQLLitePersistence) GetSessionStorage() dr.SessionStorage {
// Open opens a file at the specified path
func (s *SQLLitePersistence) Open(path string, key string) error {
db, err := sql.Open("sqlite3", path)
db, err := openDB(path, key)
if err != nil {
return err
}
// Disable concurrent access as not supported by the driver
db.SetMaxOpenConns(1)
if _, err = db.Exec("PRAGMA foreign_keys=ON"); err != nil {
return err
}
if _, err = db.Exec("PRAGMA key=ON"); err != nil {
return err
}
if _, err = db.Exec("PRAGMA cypher_page_size=4096"); err != nil {
return err
}
s.db = db
return s.setup()
@ -111,7 +156,11 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error {
for installationID, signedPreKey := range bc.GetBundle().GetSignedPreKeys() {
var version uint32
stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1")
stmt, err := tx.Prepare(`SELECT version
FROM bundles
WHERE installation_id = ? AND identity = ?
ORDER BY version DESC
LIMIT 1`)
if err != nil {
return err
}
@ -123,7 +172,8 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error {
return err
}
stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)")
stmt, err = tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp)
VALUES(?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
@ -162,7 +212,8 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
version := signedPreKeyContainer.GetVersion()
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)")
insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp)
VALUES( ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
@ -180,7 +231,9 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
return err
}
// Mark old bundles as expired
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?")
updateStmt, err := tx.Prepare(`UPDATE bundles
SET expired = 1
WHERE identity = ? AND installation_id = ? AND version < ?`)
if err != nil {
return err
}
@ -205,7 +258,9 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installationIDs []string) (*BundleContainer, error) {
/* #nosec */
statement := "SELECT identity, private_key, signed_pre_key, installation_id, timestamp FROM bundles WHERE expired = 0 AND identity = ? AND installation_id IN (?" + strings.Repeat(",?", len(installationIDs)-1) + ")"
statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version
FROM bundles
WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + ")"
stmt, err := s.db.Prepare(statement)
if err != nil {
return nil, err
@ -215,6 +270,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
var timestamp int64
var identity []byte
var privateKey []byte
var version uint32
args := make([]interface{}, len(installationIDs)+1)
args[0] = myIdentityKey
@ -249,6 +305,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
&signedPreKey,
&installationID,
&timestamp,
&version,
)
if err != nil {
return nil, err
@ -258,7 +315,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
bundle.Timestamp = timestamp
}
bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey, Version: version}
bundle.Identity = identity
}
@ -273,7 +330,9 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat
// GetPrivateKeyBundle retrieves a private key for a bundle from the database
func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) {
stmt, err := s.db.Prepare("SELECT private_key FROM bundles WHERE expired = 0 AND signed_pre_key = ? LIMIT 1")
stmt, err := s.db.Prepare(`SELECT private_key
FROM bundles
WHERE expired = 0 AND signed_pre_key = ? LIMIT 1`)
if err != nil {
return nil, err
}
@ -294,7 +353,9 @@ func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error
// MarkBundleExpired expires any private bundle for a given identity
func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
stmt, err := s.db.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND private_key IS NOT NULL")
stmt, err := s.db.Prepare(`UPDATE bundles
SET expired = 1
WHERE identity = ? AND private_key IS NOT NULL`)
if err != nil {
return err
}
@ -315,7 +376,10 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install
identity := crypto.CompressPubkey(publicKey)
/* #nosec */
statement := "SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? AND installation_id IN (?" + strings.Repeat(",?", len(installationIDs)-1) + ") ORDER BY version DESC"
statement := `SELECT signed_pre_key,installation_id, version
FROM bundles
WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `)
ORDER BY version DESC`
stmt, err := s.db.Prepare(statement)
if err != nil {
return nil, err
@ -373,7 +437,8 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install
// AddRatchetInfo persists the specified ratchet info into the database
func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error {
stmt, err := s.db.Prepare("INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) VALUES(?, ?, ?, ?, ?)")
stmt, err := s.db.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id)
VALUES(?, ?, ?, ?, ?)`)
if err != nil {
return err
}
@ -392,7 +457,10 @@ func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleI
// GetRatchetInfo retrieves the existing RatchetInfo for a specified bundle ID and interlocutor public key from the database
func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) {
stmt, err := s.db.Prepare("SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? LIMIT 1")
stmt, err := s.db.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id
FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ?
LIMIT 1`)
if err != nil {
return nil, err
}
@ -423,7 +491,10 @@ func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byt
// GetAnyRatchetInfo retrieves any existing RatchetInfo for a specified interlocutor public key from the database
func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) {
stmt, err := s.db.Prepare("SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? LIMIT 1")
stmt, err := s.db.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key
FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ?
LIMIT 1`)
if err != nil {
return nil, err
}
@ -455,7 +526,9 @@ func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID s
// RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo
// associated with the specified bundle ID and interlocutor identity public key
func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error {
stmt, err := s.db.Prepare("UPDATE ratchet_info_v2 SET ephemeral_key = NULL WHERE identity = ? AND bundle_id = ? AND installation_id = ?")
stmt, err := s.db.Prepare(`UPDATE ratchet_info_v2
SET ephemeral_key = NULL
WHERE identity = ? AND bundle_id = ? AND installation_id = ?`)
if err != nil {
return err
}
@ -474,7 +547,10 @@ func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity
func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) {
var keyBytes []byte
var key [32]byte
stmt, err := s.db.Prepare("SELECT message_key FROM keys WHERE public_key = ? AND msg_num = ? LIMIT 1")
stmt, err := s.db.Prepare(`SELECT message_key
FROM keys
WHERE public_key = ? AND msg_num = ?
LIMIT 1`)
if err != nil {
return key, false, err
@ -495,7 +571,8 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro
// Put stores a key with the specified public key, message number and message key
func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)")
stmt, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num)
VALUES(?, ?, ?, ?, ?)`)
if err != nil {
return err
}
@ -514,7 +591,8 @@ func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, m
// DeleteOldMks caps remove any key < seq_num, included
func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?")
stmt, err := s.db.Prepare(`DELETE FROM keys
WHERE session_id = ? AND seq_num <= ?`)
if err != nil {
return err
}
@ -530,7 +608,8 @@ func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) er
// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion
func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)")
stmt, err := s.db.Prepare(`DELETE FROM keys
WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`)
if err != nil {
return err
}
@ -548,7 +627,8 @@ func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int
// DeleteMk deletes the key with the specified public key and message key
func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ? AND msg_num = ?")
stmt, err := s.db.Prepare(`DELETE FROM keys
WHERE public_key = ? AND msg_num = ?`)
if err != nil {
return err
}
@ -564,7 +644,9 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
// Count returns the count of keys with the specified public key
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
stmt, err := s.db.Prepare(`SELECT COUNT(1)
FROM keys
WHERE public_key = ?`)
if err != nil {
return 0, err
}
@ -581,7 +663,8 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
// CountAll returns the count of keys with the specified public key
func (s *SQLLiteKeysStorage) CountAll() (uint, error) {
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys")
stmt, err := s.db.Prepare(`SELECT COUNT(1)
FROM keys`)
if err != nil {
return 0, err
}
@ -619,7 +702,8 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
recvChainKey := state.RecvCh.CK[:]
recvChainN := state.RecvCh.N
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
stmt, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
@ -645,7 +729,9 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
// Load retrieves the double ratchet state for a given ID
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?")
stmt, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count
FROM sessions
WHERE id = ?`)
if err != nil {
return nil, err
}
@ -710,7 +796,11 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
// GetActiveInstallations returns the active installations for a given identity
func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]string, error) {
stmt, err := s.db.Prepare("SELECT installation_id FROM installations WHERE enabled = 1 AND identity = ? ORDER BY timestamp DESC LIMIT ?")
stmt, err := s.db.Prepare(`SELECT installation_id
FROM installations
WHERE enabled = 1 AND identity = ?
ORDER BY timestamp DESC
LIMIT ?`)
if err != nil {
return nil, err
}
@ -744,7 +834,10 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
}
for _, installationID := range installationIDs {
stmt, err := tx.Prepare("SELECT enabled FROM installations WHERE identity = ? AND installation_id = ? LIMIT 1")
stmt, err := tx.Prepare(`SELECT enabled
FROM installations
WHERE identity = ? AND installation_id = ?
LIMIT 1`)
if err != nil {
return err
}
@ -759,7 +852,9 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
// We update timestamp if present without changing enabled
if err != sql.ErrNoRows {
stmt, err = tx.Prepare("UPDATE installations SET timestamp = ?, enabled = ? WHERE identity = ? AND installation_id = ?")
stmt, err = tx.Prepare(`UPDATE installations
SET timestamp = ?, enabled = ?
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}
@ -776,7 +871,8 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
defer stmt.Close()
} else {
stmt, err = tx.Prepare("INSERT INTO installations(identity, installation_id, timestamp, enabled) VALUES (?, ?, ?, ?)")
stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled)
VALUES (?, ?, ?, ?)`)
if err != nil {
return err
}
@ -806,7 +902,9 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64,
// EnableInstallation enables the installation
func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error {
stmt, err := s.db.Prepare("UPDATE installations SET enabled = 1 WHERE identity = ? AND installation_id = ?")
stmt, err := s.db.Prepare(`UPDATE installations
SET enabled = 1
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}
@ -819,7 +917,9 @@ func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID
// DisableInstallation disable the installation
func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error {
stmt, err := s.db.Prepare("UPDATE installations SET enabled = 0 WHERE identity = ? AND installation_id = ?")
stmt, err := s.db.Prepare(`UPDATE installations
SET enabled = 0
WHERE identity = ? AND installation_id = ?`)
if err != nil {
return err
}

View File

@ -6,7 +6,6 @@ import (
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
)
@ -74,7 +73,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []string{installationID})
s.Require().NoError(err)
s.NotNil(anyPrivateBundle)
s.True(proto.Equal(bundle.GetBundle(), anyPrivateBundle.GetBundle()), "It returns the same bundle")
s.Equal(bundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, anyPrivateBundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, "It returns the same bundle")
}
func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {

View File

@ -107,7 +107,14 @@ func (s *Service) InitProtocol(address string, password string) error {
if err := os.MkdirAll(filepath.Clean(s.dataDir), os.ModePerm); err != nil {
return err
}
persistence, err := chat.NewSQLLitePersistence(filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address)), password)
oldPath := filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address))
newPath := filepath.Join(s.dataDir, fmt.Sprintf("%s.db", s.installationID))
if err := chat.MigrateDBFile(oldPath, newPath, password); err != nil {
return err
}
persistence, err := chat.NewSQLLitePersistence(newPath, password)
if err != nil {
return err
}