status-go/services/shhext/chat/sql_lite_persistence.go
2018-10-16 12:31:05 +02:00

661 lines
16 KiB
Go

package chat
import (
"crypto/ecdsa"
"database/sql"
"time"
"github.com/ethereum/go-ethereum/crypto"
_ "github.com/mutecomm/go-sqlcipher" // We require go sqlcipher that overrides default implementation
dr "github.com/status-im/doubleratchet"
"github.com/status-im/migrate"
"github.com/status-im/migrate/database/sqlcipher"
"github.com/status-im/migrate/source/go_bindata"
ecrypto "github.com/status-im/status-go/services/shhext/chat/crypto"
"github.com/status-im/status-go/services/shhext/chat/migrations"
)
// SQLLitePersistence is the persistence service of the chat package. It holds
// the database handle together with the double ratchet storages that are
// built on top of that same handle.
type SQLLitePersistence struct {
	// db is the database handle assigned by Open.
	db *sql.DB

	// keysStorage is the double ratchet keys storage returned by GetKeysStorage.
	keysStorage dr.KeysStorage

	// sessionStorage is the double ratchet session storage returned by GetSessionStorage.
	sessionStorage dr.SessionStorage
}
// SQLLiteKeysStorage is the message-key store of the double ratchet, backed
// by the keys table of an SQLite database.
type SQLLiteKeysStorage struct {
	// db is the database handle every query of this storage runs against.
	db *sql.DB
}
// SQLLiteSessionStorage is the session-state store of the double ratchet,
// backed by the sessions table of an SQLite database.
type SQLLiteSessionStorage struct {
	// db is the database handle every query of this storage runs against.
	db *sql.DB
}
// NewSQLLitePersistence opens the database at path with the given key and
// returns a persistence service whose keys and session storages share that
// database handle. If opening fails, the error is returned and no service
// is created.
func NewSQLLitePersistence(path string, key string) (*SQLLitePersistence, error) {
	persistence := new(SQLLitePersistence)

	err := persistence.Open(path, key)
	if err != nil {
		return nil, err
	}

	// Both storages are created only after Open has assigned the handle.
	persistence.keysStorage = NewSQLLiteKeysStorage(persistence.db)
	persistence.sessionStorage = NewSQLLiteSessionStorage(persistence.db)

	return persistence, nil
}
// NewSQLLiteKeysStorage returns a SQLLiteKeysStorage that runs its queries
// against the given database handle.
func NewSQLLiteKeysStorage(db *sql.DB) *SQLLiteKeysStorage {
	storage := new(SQLLiteKeysStorage)
	storage.db = db
	return storage
}
// NewSQLLiteSessionStorage returns a SQLLiteSessionStorage that runs its
// queries against the given database handle.
func NewSQLLiteSessionStorage(db *sql.DB) *SQLLiteSessionStorage {
	storage := new(SQLLiteSessionStorage)
	storage.db = db
	return storage
}
// GetKeysStorage exposes the double ratchet KeysStorage held by this
// persistence service.
func (s *SQLLitePersistence) GetKeysStorage() dr.KeysStorage {
	storage := s.keysStorage
	return storage
}
// GetSessionStorage exposes the double ratchet SessionStorage held by this
// persistence service.
func (s *SQLLitePersistence) GetSessionStorage() dr.SessionStorage {
	storage := s.sessionStorage
	return storage
}
// Open opens the database file at the specified path, unlocks it with key,
// stores the handle on s and runs the schema migrations. On any failure the
// handle is closed again and the error is returned.
//
// NOTE(review): this method used to issue the literal statement
// "PRAGMA key=ON" and never used the key argument, so databases written by
// that version are keyed with the string "ON". They have to be re-keyed
// before they can be opened with their real key.
func (s *SQLLitePersistence) Open(path string, key string) error {
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		return err
	}

	// Disable concurrent access as not supported by the driver
	db.SetMaxOpenConns(1)

	pragmas := []string{
		// The key is supplied before anything else touches the database.
		"PRAGMA key=" + quoteSQLString(key),
		"PRAGMA foreign_keys=ON",
		// NOTE(review): SQLCipher spells its page-size pragma
		// cipher_page_size, and SQLite ignores pragmas it does not know, so
		// this statement looks like a no-op. It is kept as is because
		// changing the page size affects how existing files are read.
		"PRAGMA cypher_page_size=4096",
	}
	for _, pragma := range pragmas {
		if _, err = db.Exec(pragma); err != nil {
			// Best-effort close: the Exec error is the one worth reporting.
			_ = db.Close()
			return err
		}
	}

	s.db = db
	if err = s.setup(); err != nil {
		// Do not leave a closed handle behind on the receiver.
		s.db = nil
		_ = db.Close()
		return err
	}

	return nil
}

// quoteSQLString renders v as a single-quoted SQL string literal, doubling
// every embedded single quote so the value cannot end the literal early.
func quoteSQLString(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if v[i] == '\'' {
			quoted = append(quoted, '\'')
		}
		quoted = append(quoted, v[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}
// AddPrivateBundle stores the specified BundleContainer in the bundles
// table: one row per signed pre-key of the bundle, each carrying the bundle
// identity, the container's private signed pre-key, the installation ID and
// the current time in nanoseconds.
//
// All rows are written in a single transaction, which is rolled back as soon
// as any step fails.
func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	// Prepared once and reused for every row. A failure here must roll back,
	// otherwise the open transaction keeps hold of its connection.
	stmt, err := tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, timestamp) VALUES(?, ?, ?, ?, ?)")
	if err != nil {
		_ = tx.Rollback()
		return err
	}
	defer stmt.Close()

	for installationID, signedPreKey := range b.GetBundle().GetSignedPreKeys() {
		_, err = stmt.Exec(
			b.GetBundle().GetIdentity(),
			b.GetPrivateSignedPreKey(),
			signedPreKey.GetSignedPreKey(),
			installationID,
			time.Now().UnixNano(),
		)
		if err != nil {
			// The Exec error is the one reported; Rollback is best effort.
			_ = tx.Rollback()
			return err
		}
	}

	// Once Commit has returned the transaction is finished either way, so
	// there is nothing left to roll back on failure.
	return tx.Commit()
}
// AddPublicBundle stores the specified Bundle in the bundles table. For each
// installation in the bundle it inserts the signed pre-key and then marks
// every other row with the same identity and installation ID as expired.
//
// Everything runs in a single transaction, which is rolled back as soon as
// any step fails.
func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	// Both statements are prepared once and reused for every installation.
	// A failure here must roll back, otherwise the open transaction keeps
	// hold of its connection.
	insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, timestamp) VALUES( ?, ?, ?, ?)")
	if err != nil {
		_ = tx.Rollback()
		return err
	}
	defer insertStmt.Close()

	updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND signed_pre_key != ?")
	if err != nil {
		_ = tx.Rollback()
		return err
	}
	defer updateStmt.Close()

	for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
		signedPreKey := signedPreKeyContainer.GetSignedPreKey()

		_, err = insertStmt.Exec(
			b.GetIdentity(),
			signedPreKey,
			installationID,
			time.Now().UnixNano(),
		)
		if err != nil {
			_ = tx.Rollback()
			return err
		}

		// Mark old bundles as expired
		_, err = updateStmt.Exec(
			b.GetIdentity(),
			installationID,
			signedPreKey,
		)
		if err != nil {
			_ = tx.Rollback()
			return err
		}
	}

	return tx.Commit()
}
// GetAnyPrivateBundle collects every unexpired bundle row stored for
// myIdentityKey into a single BundleContainer, with one signed pre-key per
// installation ID. The container timestamp is taken from a row that carries
// a private key.
//
// It returns nil, nil when there are no such rows or none of them has a
// private key, and a non-nil error when the query or the iteration fails.
func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte) (*BundleContainer, error) {
	stmt, err := s.db.Prepare("SELECT identity, private_key, signed_pre_key, installation_id, timestamp FROM bundles WHERE identity = ? AND expired = 0")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, err := stmt.Query(myIdentityKey)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	bundle := &Bundle{
		SignedPreKeys: make(map[string]*SignedPreKey),
	}
	bundleContainer := &BundleContainer{
		Bundle: bundle,
	}

	rowCount := 0
	for rows.Next() {
		var (
			identity       []byte
			privateKey     []byte
			signedPreKey   []byte
			installationID string
			timestamp      int64
		)
		rowCount++

		err = rows.Scan(
			&identity,
			&privateKey,
			&signedPreKey,
			&installationID,
			&timestamp,
		)
		if err != nil {
			return nil, err
		}

		// If there is a private key, we set the timestamp of the bundle container
		if privateKey != nil {
			bundleContainer.Timestamp = timestamp
		}

		bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
		bundle.Identity = identity
	}

	// Next also returns false when the iteration breaks part-way; without
	// this check such a failure would look like a short or empty result.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	// If no records are found or no record with private key, return nil
	if rowCount == 0 || bundleContainer.Timestamp == 0 {
		return nil, nil
	}

	return bundleContainer, nil
}
// GetPrivateKeyBundle looks up the private key of the unexpired bundle whose
// signed pre-key equals bundleID. It returns nil, nil when no such bundle
// exists.
func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) {
	query, err := s.db.Prepare("SELECT private_key FROM bundles WHERE expired = 0 AND signed_pre_key = ? LIMIT 1")
	if err != nil {
		return nil, err
	}
	defer query.Close()

	var key []byte
	if err = query.QueryRow(bundleID).Scan(&key); err != nil {
		// A missing row is reported as "no key", not as an error.
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, err
	}

	return key, nil
}
// MarkBundleExpired flags every bundle stored for the specified identity as
// expired.
func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
	stmt, err := s.db.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ?")
	if err != nil {
		return err
	}
	defer stmt.Close()

	_, err = stmt.Exec(identity)
	return err
}
// GetPublicBundle builds a Bundle from the unexpired rows stored for the
// compressed form of publicKey, with one signed pre-key per installation ID.
//
// It returns nil, nil when no such row exists, and a non-nil error when the
// query or the iteration fails.
func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundle, error) {
	identity := crypto.CompressPubkey(publicKey)

	stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id FROM bundles WHERE expired = 0 AND identity = ? ORDER BY timestamp DESC")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, err := stmt.Query(identity)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	bundle := &Bundle{
		Identity:      identity,
		SignedPreKeys: make(map[string]*SignedPreKey),
	}

	rowCount := 0
	for rows.Next() {
		var signedPreKey []byte
		var installationID string
		rowCount++

		err = rows.Scan(
			&signedPreKey,
			&installationID,
		)
		if err != nil {
			return nil, err
		}

		bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
	}

	// Next also returns false when the iteration breaks part-way; without
	// this check such a failure would look like a short or empty result.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	if rowCount == 0 {
		return nil, nil
	}

	return bundle, nil
}
// AddRatchetInfo inserts one row into ratchet_info_v2 holding the symmetric
// key, the interlocutor identity, the bundle ID, the ephemeral key and the
// installation ID.
func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error {
	insert, err := s.db.Prepare("INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) VALUES(?, ?, ?, ?, ?)")
	if err != nil {
		return err
	}
	defer insert.Close()

	if _, err = insert.Exec(key, identity, bundleID, ephemeralKey, installationID); err != nil {
		return err
	}
	return nil
}
// GetRatchetInfo retrieves the RatchetInfo stored for the specified bundle
// ID, interlocutor identity and installation ID, joined with the bundle
// whose signed pre-key equals the bundle ID. The ID of the result is the
// bundle ID followed by the bytes of the installation ID.
//
// It returns nil, nil when no matching row exists.
func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) {
	stmt, err := s.db.Prepare("SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? LIMIT 1")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	ratchetInfo := &RatchetInfo{
		BundleID: bundleID,
	}

	err = stmt.QueryRow(theirIdentity, installationID, bundleID).Scan(
		&ratchetInfo.Identity,
		&ratchetInfo.Sk,
		&ratchetInfo.PrivateKey,
		&ratchetInfo.PublicKey,
		&ratchetInfo.EphemeralKey,
		&ratchetInfo.InstallationID,
	)
	switch err {
	case sql.ErrNoRows:
		return nil, nil
	case nil:
		// Build the ID in its own slice: appending to bundleID directly
		// could write into spare capacity of the caller's backing array and
		// would make ID share storage with BundleID.
		id := make([]byte, 0, len(bundleID)+len(ratchetInfo.InstallationID))
		id = append(id, bundleID...)
		id = append(id, ratchetInfo.InstallationID...)
		ratchetInfo.ID = id
		return ratchetInfo, nil
	default:
		return nil, err
	}
}
// GetAnyRatchetInfo fetches one RatchetInfo stored for the given interlocutor
// identity and installation ID whose bundle has not expired. The ID of the
// result is the bundle ID followed by the bytes of the installation ID.
//
// It returns nil, nil when no matching row exists.
func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) {
	query, err := s.db.Prepare("SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? LIMIT 1")
	if err != nil {
		return nil, err
	}
	defer query.Close()

	// Identity and installation ID are the lookup values themselves; the
	// remaining fields are filled from the row.
	info := &RatchetInfo{
		Identity:       identity,
		InstallationID: installationID,
	}

	row := query.QueryRow(identity, installationID)
	err = row.Scan(
		&info.Sk,
		&info.PrivateKey,
		&info.PublicKey,
		&info.BundleID,
		&info.EphemeralKey,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	info.ID = append(info.BundleID, installationID...)
	return info, nil
}
// RatchetInfoConfirmed sets the ephemeral key to NULL on the ratchet info
// rows matching the specified bundle ID, interlocutor identity and
// installation ID.
func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error {
	update, err := s.db.Prepare("UPDATE ratchet_info_v2 SET ephemeral_key = NULL WHERE identity = ? AND bundle_id = ? AND installation_id = ?")
	if err != nil {
		return err
	}
	defer update.Close()

	// Argument order follows the placeholders: identity, bundle, installation.
	if _, err = update.Exec(theirIdentity, bundleID, installationID); err != nil {
		return err
	}
	return nil
}
// Get looks up the message key stored for the given public key and message
// number. The boolean result reports whether a key was found; when none is
// found the returned key is the zero key.
func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) {
	var mk dr.Key

	lookup, err := s.db.Prepare("SELECT message_key FROM keys WHERE public_key = ? AND msg_num = ? LIMIT 1")
	if err != nil {
		return mk, false, err
	}
	defer lookup.Close()

	var raw []byte
	err = lookup.QueryRow(pubKey[:], msgNum).Scan(&raw)
	if err == sql.ErrNoRows {
		return mk, false, nil
	}
	if err != nil {
		return mk, false, err
	}

	copy(mk[:], raw)
	return mk, true, nil
}
// Put inserts a row into the keys table holding the given public key,
// message number and message key.
func (s *SQLLiteKeysStorage) Put(pubKey dr.Key, msgNum uint, mk dr.Key) error {
	insert, err := s.db.Prepare("insert into keys(public_key, msg_num, message_key) values(?, ?, ?)")
	if err != nil {
		return err
	}
	defer insert.Close()

	if _, err = insert.Exec(pubKey[:], msgNum, mk[:]); err != nil {
		return err
	}
	return nil
}
// DeleteMk removes the key stored for the specified public key and message
// number.
func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
	remove, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ? AND msg_num = ?")
	if err != nil {
		return err
	}
	defer remove.Close()

	if _, err = remove.Exec(pubKey[:], msgNum); err != nil {
		return err
	}
	return nil
}
// DeletePk removes every key stored for the specified public key.
func (s *SQLLiteKeysStorage) DeletePk(pubKey dr.Key) error {
	remove, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ?")
	if err != nil {
		return err
	}
	defer remove.Close()

	if _, err = remove.Exec(pubKey[:]); err != nil {
		return err
	}
	return nil
}
// Count reports how many keys are stored for the specified public key.
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
	query, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
	if err != nil {
		return 0, err
	}
	defer query.Close()

	var total uint
	if err = query.QueryRow(pubKey[:]).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// All does not read the database: it always returns a nil map and a nil
// error.
func (s *SQLLiteKeysStorage) All() (map[dr.Key]map[uint]dr.Key, error) {
	var keys map[dr.Key]map[uint]dr.Key
	return keys, nil
}
// Save inserts the specified double ratchet state into the sessions table
// under the given id.
func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
	// The DH key pair values are held in locals so that they can be sliced.
	dhsPublic := state.DHs.PublicKey()
	dhsPrivate := state.DHs.PrivateKey()

	// Listed in the column order of the insert statement below.
	args := []interface{}{
		id,
		state.DHr[:],
		dhsPublic[:],
		dhsPrivate[:],
		state.RootCh.CK[:],
		state.SendCh.CK[:],
		state.SendCh.N,
		state.RecvCh.CK[:],
		state.RecvCh.N,
		state.PN,
		state.Step,
	}

	insert, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
	if err != nil {
		return err
	}
	defer insert.Close()

	_, err = insert.Exec(args...)
	return err
}
// Load reads the double ratchet state stored under the given ID and rebuilds
// it on top of dr.DefaultState. It returns nil, nil when no session is
// stored for that ID.
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
	query, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step FROM sessions WHERE id = ?")
	if err != nil {
		return nil, err
	}
	defer query.Close()

	// Scan targets, declared in the column order of the query.
	var (
		dhr, dhsPub, dhsPrv   []byte
		rootCK                []byte
		sendCK                []byte
		sendN                 uint
		recvCK                []byte
		recvN, prevN, stepVal uint
	)

	err = query.QueryRow(id).Scan(
		&dhr, &dhsPub, &dhsPrv,
		&rootCK,
		&sendCK, &sendN,
		&recvCK, &recvN,
		&prevN, &stepVal,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	// Start from the default state for the stored root chain key and then
	// overwrite every persisted field.
	state := dr.DefaultState(toKey(rootCK))

	state.DHr = toKey(dhr)
	state.DHs = ecrypto.DHPair{
		PrvKey: toKey(dhsPrv),
		PubKey: toKey(dhsPub),
	}

	state.SendCh.CK = toKey(sendCK)
	state.SendCh.N = uint32(sendN)

	state.RecvCh.CK = toKey(recvCK)
	state.RecvCh.N = uint32(recvN)

	state.PN = uint32(prevN)
	state.Step = stepVal

	return &state, nil
}
// toKey copies a into a fixed-size 32-byte key. Input beyond 32 bytes is
// dropped and a shorter input leaves the remaining bytes zero.
func toKey(a []byte) dr.Key {
	var out dr.Key
	copy(out[:], a)
	return out
}
// setup applies the migrations embedded in the migrations package to s.db.
// Having nothing left to migrate counts as success.
func (s *SQLLitePersistence) setup() error {
	// Migration source: the assets bundled in the migrations package.
	assets := bindata.Resource(migrations.AssetNames(), func(name string) ([]byte, error) {
		return migrations.Asset(name)
	})
	source, err := bindata.WithInstance(assets)
	if err != nil {
		return err
	}

	// Migration target: the database handle already stored on s.
	driver, err := sqlcipher.WithInstance(s.db, &sqlcipher.Config{})
	if err != nil {
		return err
	}

	migrator, err := migrate.NewWithInstance("go-bindata", source, "sqlcipher", driver)
	if err != nil {
		return err
	}

	err = migrator.Up()
	if err == migrate.ErrNoChange {
		return nil
	}
	return err
}