1012 lines
23 KiB
Go
1012 lines
23 KiB
Go
package encryption
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"database/sql"
|
|
"errors"
|
|
"strings"
|
|
|
|
dr "github.com/status-im/doubleratchet"
|
|
|
|
"github.com/status-im/status-go/eth-node/crypto"
|
|
|
|
"github.com/status-im/status-go/protocol/encryption/multidevice"
|
|
)
|
|
|
|
// RatchetInfo holds the current ratchet state.
|
|
type RatchetInfo struct {
|
|
ID []byte
|
|
Sk []byte
|
|
PrivateKey []byte
|
|
PublicKey []byte
|
|
Identity []byte
|
|
BundleID []byte
|
|
EphemeralKey []byte
|
|
InstallationID string
|
|
}
|
|
|
|
// A safe max number of rows.
|
|
const maxNumberOfRows = 100000000
|
|
|
|
type sqlitePersistence struct {
|
|
DB *sql.DB
|
|
keysStorage dr.KeysStorage
|
|
sessionStorage dr.SessionStorage
|
|
}
|
|
|
|
func newSQLitePersistence(db *sql.DB) *sqlitePersistence {
|
|
return &sqlitePersistence{
|
|
DB: db,
|
|
keysStorage: newSQLiteKeysStorage(db),
|
|
sessionStorage: newSQLiteSessionStorage(db),
|
|
}
|
|
}
|
|
|
|
// GetKeysStorage returns the associated double ratchet KeysStorage object
|
|
func (s *sqlitePersistence) KeysStorage() dr.KeysStorage {
|
|
return s.keysStorage
|
|
}
|
|
|
|
// GetSessionStorage returns the associated double ratchet SessionStorage object
|
|
func (s *sqlitePersistence) SessionStorage() dr.SessionStorage {
|
|
return s.sessionStorage
|
|
}
|
|
|
|
// AddPrivateBundle adds the specified BundleContainer to the database
|
|
func (s *sqlitePersistence) AddPrivateBundle(bc *BundleContainer) error {
|
|
tx, err := s.DB.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for installationID, signedPreKey := range bc.GetBundle().GetSignedPreKeys() {
|
|
var version uint32
|
|
stmt, err := tx.Prepare(`SELECT version
|
|
FROM bundles
|
|
WHERE installation_id = ? AND identity = ?
|
|
ORDER BY version DESC
|
|
LIMIT 1`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer stmt.Close()
|
|
|
|
err = stmt.QueryRow(installationID, bc.GetBundle().GetIdentity()).Scan(&version)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
|
|
stmt, err = tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp)
|
|
VALUES(?, ?, ?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
bc.GetBundle().GetIdentity(),
|
|
bc.GetPrivateSignedPreKey(),
|
|
signedPreKey.GetSignedPreKey(),
|
|
installationID,
|
|
version+1,
|
|
bc.GetBundle().GetTimestamp(),
|
|
)
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddPublicBundle adds the specified Bundle to the database
|
|
func (s *sqlitePersistence) AddPublicBundle(b *Bundle) error {
|
|
tx, err := s.DB.Begin()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
|
|
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
|
|
version := signedPreKeyContainer.GetVersion()
|
|
insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp)
|
|
VALUES( ?, ?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer insertStmt.Close()
|
|
|
|
_, err = insertStmt.Exec(
|
|
b.GetIdentity(),
|
|
signedPreKey,
|
|
installationID,
|
|
version,
|
|
b.GetTimestamp(),
|
|
)
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
// Mark old bundles as expired
|
|
updateStmt, err := tx.Prepare(`UPDATE bundles
|
|
SET expired = 1
|
|
WHERE identity = ? AND installation_id = ? AND version < ?`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer updateStmt.Close()
|
|
|
|
_, err = updateStmt.Exec(
|
|
b.GetIdentity(),
|
|
installationID,
|
|
version,
|
|
)
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// GetAnyPrivateBundle retrieves any bundle from the database containing a private key
|
|
func (s *sqlitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*BundleContainer, error) {
|
|
|
|
versions := make(map[string]uint32)
|
|
/* #nosec */
|
|
statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version
|
|
FROM bundles
|
|
WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + ")"
|
|
stmt, err := s.DB.Prepare(statement)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
var timestamp int64
|
|
var identity []byte
|
|
var privateKey []byte
|
|
var version uint32
|
|
|
|
args := make([]interface{}, len(installations)+1)
|
|
args[0] = myIdentityKey
|
|
for i, installation := range installations {
|
|
// Lookup up map for versions
|
|
versions[installation.ID] = installation.Version
|
|
|
|
args[i+1] = installation.ID
|
|
}
|
|
|
|
rows, err := stmt.Query(args...)
|
|
rowCount := 0
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
bundle := &Bundle{
|
|
SignedPreKeys: make(map[string]*SignedPreKey),
|
|
}
|
|
|
|
bundleContainer := &BundleContainer{
|
|
Bundle: bundle,
|
|
}
|
|
|
|
for rows.Next() {
|
|
var signedPreKey []byte
|
|
var installationID string
|
|
rowCount++
|
|
err = rows.Scan(
|
|
&identity,
|
|
&privateKey,
|
|
&signedPreKey,
|
|
&installationID,
|
|
×tamp,
|
|
&version,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// If there is a private key, we set the timestamp of the bundle container
|
|
if privateKey != nil {
|
|
bundle.Timestamp = timestamp
|
|
}
|
|
|
|
bundle.SignedPreKeys[installationID] = &SignedPreKey{
|
|
SignedPreKey: signedPreKey,
|
|
Version: version,
|
|
ProtocolVersion: versions[installationID],
|
|
}
|
|
bundle.Identity = identity
|
|
}
|
|
|
|
// If no records are found or no record with private key, return nil
|
|
if rowCount == 0 || bundleContainer.GetBundle().Timestamp == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return bundleContainer, nil
|
|
|
|
}
|
|
|
|
// GetPrivateKeyBundle retrieves a private key for a bundle from the database
|
|
func (s *sqlitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) {
|
|
stmt, err := s.DB.Prepare(`SELECT private_key
|
|
FROM bundles
|
|
WHERE signed_pre_key = ? LIMIT 1`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
var privateKey []byte
|
|
|
|
err = stmt.QueryRow(bundleID).Scan(&privateKey)
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return nil, nil
|
|
case nil:
|
|
return privateKey, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// MarkBundleExpired expires any private bundle for a given identity
|
|
func (s *sqlitePersistence) MarkBundleExpired(identity []byte) error {
|
|
stmt, err := s.DB.Prepare(`UPDATE bundles
|
|
SET expired = 1
|
|
WHERE identity = ? AND private_key IS NOT NULL`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(identity)
|
|
|
|
return err
|
|
}
|
|
|
|
// GetPublicBundle retrieves an existing Bundle for the specified public key from the database
|
|
func (s *sqlitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*Bundle, error) {
|
|
|
|
if len(installations) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
versions := make(map[string]uint32)
|
|
identity := crypto.CompressPubkey(publicKey)
|
|
|
|
/* #nosec */
|
|
statement := `SELECT signed_pre_key,installation_id, version
|
|
FROM bundles
|
|
WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + `)
|
|
ORDER BY version DESC`
|
|
stmt, err := s.DB.Prepare(statement)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
args := make([]interface{}, len(installations)+1)
|
|
args[0] = identity
|
|
for i, installation := range installations {
|
|
// Lookup up map for versions
|
|
versions[installation.ID] = installation.Version
|
|
args[i+1] = installation.ID
|
|
}
|
|
|
|
rows, err := stmt.Query(args...)
|
|
rowCount := 0
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
bundle := &Bundle{
|
|
Identity: identity,
|
|
SignedPreKeys: make(map[string]*SignedPreKey),
|
|
}
|
|
|
|
for rows.Next() {
|
|
var signedPreKey []byte
|
|
var installationID string
|
|
var version uint32
|
|
rowCount++
|
|
err = rows.Scan(
|
|
&signedPreKey,
|
|
&installationID,
|
|
&version,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
bundle.SignedPreKeys[installationID] = &SignedPreKey{
|
|
SignedPreKey: signedPreKey,
|
|
Version: version,
|
|
ProtocolVersion: versions[installationID],
|
|
}
|
|
|
|
}
|
|
|
|
if rowCount == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return bundle, nil
|
|
|
|
}
|
|
|
|
// AddRatchetInfo persists the specified ratchet info into the database
|
|
func (s *sqlitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error {
|
|
stmt, err := s.DB.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id)
|
|
VALUES(?, ?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
key,
|
|
identity,
|
|
bundleID,
|
|
ephemeralKey,
|
|
installationID,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// GetRatchetInfo retrieves the existing RatchetInfo for a specified bundle ID and interlocutor public key from the database
|
|
func (s *sqlitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) {
|
|
stmt, err := s.DB.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id
|
|
FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
|
|
WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ?
|
|
LIMIT 1`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
ratchetInfo := &RatchetInfo{
|
|
BundleID: bundleID,
|
|
}
|
|
|
|
err = stmt.QueryRow(theirIdentity, installationID, bundleID).Scan(
|
|
&ratchetInfo.Identity,
|
|
&ratchetInfo.Sk,
|
|
&ratchetInfo.PrivateKey,
|
|
&ratchetInfo.PublicKey,
|
|
&ratchetInfo.EphemeralKey,
|
|
&ratchetInfo.InstallationID,
|
|
)
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return nil, nil
|
|
case nil:
|
|
ratchetInfo.ID = append(bundleID, []byte(ratchetInfo.InstallationID)...)
|
|
return ratchetInfo, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// GetAnyRatchetInfo retrieves any existing RatchetInfo for a specified interlocutor public key from the database
|
|
func (s *sqlitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) {
|
|
stmt, err := s.DB.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key
|
|
FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
|
|
WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ?
|
|
LIMIT 1`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
ratchetInfo := &RatchetInfo{
|
|
Identity: identity,
|
|
InstallationID: installationID,
|
|
}
|
|
|
|
err = stmt.QueryRow(identity, installationID).Scan(
|
|
&ratchetInfo.Sk,
|
|
&ratchetInfo.PrivateKey,
|
|
&ratchetInfo.PublicKey,
|
|
&ratchetInfo.BundleID,
|
|
&ratchetInfo.EphemeralKey,
|
|
)
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return nil, nil
|
|
case nil:
|
|
ratchetInfo.ID = append(ratchetInfo.BundleID, []byte(installationID)...)
|
|
return ratchetInfo, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo
|
|
// associated with the specified bundle ID and interlocutor identity public key
|
|
func (s *sqlitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error {
|
|
stmt, err := s.DB.Prepare(`UPDATE ratchet_info_v2
|
|
SET ephemeral_key = NULL
|
|
WHERE identity = ? AND bundle_id = ? AND installation_id = ?`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
theirIdentity,
|
|
bundleID,
|
|
installationID,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
type sqliteKeysStorage struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func newSQLiteKeysStorage(db *sql.DB) *sqliteKeysStorage {
|
|
return &sqliteKeysStorage{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// Get retrieves the message key for a specified public key and message number
|
|
func (s *sqliteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) {
|
|
var key []byte
|
|
stmt, err := s.db.Prepare(`SELECT message_key
|
|
FROM keys
|
|
WHERE public_key = ? AND msg_num = ?
|
|
LIMIT 1`)
|
|
|
|
if err != nil {
|
|
return key, false, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
err = stmt.QueryRow(pubKey, msgNum).Scan(&key)
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return key, false, nil
|
|
case nil:
|
|
return key, true, nil
|
|
default:
|
|
return key, false, err
|
|
}
|
|
}
|
|
|
|
// Put stores a key with the specified public key, message number and message key
|
|
func (s *sqliteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
|
|
stmt, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num)
|
|
VALUES(?, ?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
sessionID,
|
|
pubKey,
|
|
msgNum,
|
|
mk,
|
|
seqNum,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// DeleteOldMks caps remove any key < seq_num, included
|
|
func (s *sqliteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
|
|
stmt, err := s.db.Prepare(`DELETE FROM keys
|
|
WHERE session_id = ? AND seq_num <= ?`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
sessionID,
|
|
deleteUntil,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion
|
|
func (s *sqliteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
|
|
stmt, err := s.db.Prepare(`DELETE FROM keys
|
|
WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
sessionID,
|
|
// We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT
|
|
maxNumberOfRows,
|
|
maxKeysPerSession,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// DeleteMk deletes the key with the specified public key and message key
|
|
func (s *sqliteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
|
|
stmt, err := s.db.Prepare(`DELETE FROM keys
|
|
WHERE public_key = ? AND msg_num = ?`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
pubKey,
|
|
msgNum,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// Count returns the count of keys with the specified public key
|
|
func (s *sqliteKeysStorage) Count(pubKey dr.Key) (uint, error) {
|
|
stmt, err := s.db.Prepare(`SELECT COUNT(1)
|
|
FROM keys
|
|
WHERE public_key = ?`)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
var count uint
|
|
err = stmt.QueryRow(pubKey).Scan(&count)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// CountAll returns the count of keys with the specified public key
|
|
func (s *sqliteKeysStorage) CountAll() (uint, error) {
|
|
stmt, err := s.db.Prepare(`SELECT COUNT(1)
|
|
FROM keys`)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
var count uint
|
|
err = stmt.QueryRow().Scan(&count)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// All returns nil
|
|
func (s *sqliteKeysStorage) All() (map[string]map[uint]dr.Key, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
type sqliteSessionStorage struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func newSQLiteSessionStorage(db *sql.DB) *sqliteSessionStorage {
|
|
return &sqliteSessionStorage{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// Save persists the specified double ratchet state
|
|
func (s *sqliteSessionStorage) Save(id []byte, state *dr.State) error {
|
|
dhr := state.DHr
|
|
dhs := state.DHs
|
|
dhsPublic := dhs.PublicKey()
|
|
dhsPrivate := dhs.PrivateKey()
|
|
pn := state.PN
|
|
step := state.Step
|
|
keysCount := state.KeysCount
|
|
|
|
rootChainKey := state.RootCh.CK
|
|
|
|
sendChainKey := state.SendCh.CK
|
|
sendChainN := state.SendCh.N
|
|
|
|
recvChainKey := state.RecvCh.CK
|
|
recvChainN := state.RecvCh.N
|
|
|
|
stmt, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count)
|
|
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.Exec(
|
|
id,
|
|
dhr,
|
|
dhsPublic,
|
|
dhsPrivate,
|
|
rootChainKey,
|
|
sendChainKey,
|
|
sendChainN,
|
|
recvChainKey,
|
|
recvChainN,
|
|
pn,
|
|
step,
|
|
keysCount,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// Load retrieves the double ratchet state for a given ID
|
|
func (s *sqliteSessionStorage) Load(id []byte) (*dr.State, error) {
|
|
stmt, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count
|
|
FROM sessions
|
|
WHERE id = ?`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer stmt.Close()
|
|
|
|
var (
|
|
dhr []byte
|
|
dhsPublic []byte
|
|
dhsPrivate []byte
|
|
rootChainKey []byte
|
|
sendChainKey []byte
|
|
sendChainN uint
|
|
recvChainKey []byte
|
|
recvChainN uint
|
|
pn uint
|
|
step uint
|
|
keysCount uint
|
|
)
|
|
|
|
err = stmt.QueryRow(id).Scan(
|
|
&dhr,
|
|
&dhsPublic,
|
|
&dhsPrivate,
|
|
&rootChainKey,
|
|
&sendChainKey,
|
|
&sendChainN,
|
|
&recvChainKey,
|
|
&recvChainN,
|
|
&pn,
|
|
&step,
|
|
&keysCount,
|
|
)
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return nil, nil
|
|
case nil:
|
|
state := dr.DefaultState(rootChainKey)
|
|
|
|
state.PN = uint32(pn)
|
|
state.Step = step
|
|
state.KeysCount = keysCount
|
|
|
|
state.DHs = crypto.DHPair{
|
|
PrvKey: dhsPrivate,
|
|
PubKey: dhsPublic,
|
|
}
|
|
|
|
state.DHr = dhr
|
|
|
|
state.SendCh.CK = sendChainKey
|
|
state.SendCh.N = uint32(sendChainN)
|
|
|
|
state.RecvCh.CK = recvChainKey
|
|
state.RecvCh.N = uint32(recvChainN)
|
|
|
|
return &state, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
type HRCache struct {
|
|
GroupID []byte
|
|
KeyID []byte
|
|
DeprecatedKeyID uint32
|
|
Key []byte
|
|
Hash []byte
|
|
SeqNo uint32
|
|
}
|
|
|
|
// GetHashRatchetCache retrieves a hash ratchet key by group ID and seqNo.
|
|
// If cache data with given seqNo (e.g. 0) is not found,
|
|
// then the query will return the cache data with the latest seqNo
|
|
func (s *sqlitePersistence) GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (*HRCache, error) {
|
|
tx, err := s.DB.BeginTx(context.Background(), &sql.TxOptions{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
// don't shadow original error
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
var key, keyID []byte
|
|
if !ratchet.IsOldFormat() {
|
|
keyID, err = ratchet.GetKeyID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
err = tx.QueryRow("SELECT key FROM hash_ratchet_encryption WHERE key_id = ? OR (deprecated_key_id = ? AND group_id = ?)",
|
|
keyID,
|
|
ratchet.DeprecatedKeyID(),
|
|
ratchet.GroupID,
|
|
).Scan(&key)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
args := make([]interface{}, 0)
|
|
args = append(args, ratchet.GroupID)
|
|
args = append(args, keyID)
|
|
args = append(args, ratchet.DeprecatedKeyID())
|
|
var query string
|
|
if seqNo == 0 {
|
|
query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) ORDER BY seq_no DESC limit 1"
|
|
} else {
|
|
query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) AND seq_no == ? ORDER BY seq_no DESC limit 1"
|
|
args = append(args, seqNo)
|
|
}
|
|
|
|
var hash []byte
|
|
var seqNoPtr *uint32
|
|
|
|
err = tx.QueryRow(query, args...).Scan(&seqNoPtr, &hash) //nolint: ineffassign,staticcheck
|
|
switch err {
|
|
case sql.ErrNoRows, nil:
|
|
var seqNoResult uint32
|
|
if seqNoPtr == nil {
|
|
seqNoResult = 0
|
|
} else {
|
|
seqNoResult = *seqNoPtr
|
|
}
|
|
|
|
ratchet.Key = key
|
|
keyID, err := ratchet.GetKeyID()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res := &HRCache{
|
|
KeyID: keyID,
|
|
Key: key,
|
|
Hash: hash,
|
|
SeqNo: seqNoResult,
|
|
}
|
|
|
|
return res, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
type HashRatchetKeyCompatibility struct {
|
|
GroupID []byte
|
|
keyID []byte
|
|
Timestamp uint64
|
|
Key []byte
|
|
}
|
|
|
|
func (h *HashRatchetKeyCompatibility) DeprecatedKeyID() uint32 {
|
|
return uint32(h.Timestamp)
|
|
}
|
|
|
|
func (h *HashRatchetKeyCompatibility) IsOldFormat() bool {
|
|
return len(h.keyID) == 0 && len(h.Key) == 0
|
|
}
|
|
|
|
func (h *HashRatchetKeyCompatibility) GetKeyID() ([]byte, error) {
|
|
if len(h.keyID) != 0 {
|
|
return h.keyID, nil
|
|
}
|
|
|
|
if len(h.GroupID) == 0 || h.Timestamp == 0 || len(h.Key) == 0 {
|
|
return nil, errors.New("could not create key")
|
|
}
|
|
|
|
return generateHashRatchetKeyID(h.GroupID, h.Timestamp, h.Key), nil
|
|
}
|
|
|
|
func (h *HashRatchetKeyCompatibility) GenerateNext() (*HashRatchetKeyCompatibility, error) {
|
|
|
|
ratchet := &HashRatchetKeyCompatibility{
|
|
GroupID: h.GroupID,
|
|
}
|
|
|
|
// Randomly generate a hash ratchet key
|
|
hrKey, err := crypto.GenerateKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hrKeyBytes := crypto.FromECDSA(hrKey)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
currentTime := GetCurrentTime()
|
|
if h.Timestamp < currentTime {
|
|
ratchet.Timestamp = bumpKeyID(currentTime)
|
|
} else {
|
|
ratchet.Timestamp = h.Timestamp + 1
|
|
}
|
|
|
|
ratchet.Key = hrKeyBytes
|
|
|
|
_, err = ratchet.GetKeyID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ratchet, nil
|
|
}
|
|
|
|
// GetCurrentKeyForGroup retrieves a key ID for given group ID
|
|
// (with an assumption that key ids are shared in the group, and
|
|
// at any given time there is a single key used)
|
|
func (s *sqlitePersistence) GetCurrentKeyForGroup(groupID []byte) (*HashRatchetKeyCompatibility, error) {
|
|
ratchet := &HashRatchetKeyCompatibility{
|
|
GroupID: groupID,
|
|
}
|
|
|
|
stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key
|
|
FROM hash_ratchet_encryption
|
|
WHERE group_id = ? order by key_timestamp desc limit 1`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
var keyID, key []byte
|
|
var timestamp uint64
|
|
err = stmt.QueryRow(groupID).Scan(&keyID, ×tamp, &key)
|
|
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return ratchet, nil
|
|
case nil:
|
|
ratchet.Key = key
|
|
ratchet.Timestamp = timestamp
|
|
_, err = ratchet.GetKeyID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ratchet, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// GetKeysForGroup retrieves all key IDs for given group ID
|
|
func (s *sqlitePersistence) GetKeysForGroup(groupID []byte) ([]*HashRatchetKeyCompatibility, error) {
|
|
|
|
var ratchets []*HashRatchetKeyCompatibility
|
|
stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key
|
|
FROM hash_ratchet_encryption
|
|
WHERE group_id = ? order by key_timestamp desc`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query(groupID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
ratchet := &HashRatchetKeyCompatibility{GroupID: groupID}
|
|
err := rows.Scan(&ratchet.keyID, &ratchet.Timestamp, &ratchet.Key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ratchets = append(ratchets, ratchet)
|
|
}
|
|
|
|
return ratchets, nil
|
|
}
|
|
|
|
// SaveHashRatchetKeyHash saves a hash ratchet key cache data
|
|
func (s *sqlitePersistence) SaveHashRatchetKeyHash(
|
|
ratchet *HashRatchetKeyCompatibility,
|
|
hash []byte,
|
|
seqNo uint32,
|
|
) error {
|
|
|
|
stmt, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption_cache(group_id, key_id, hash, seq_no)
|
|
VALUES(?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
keyID, err := ratchet.GetKeyID()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = stmt.Exec(ratchet.GroupID, keyID, hash, seqNo)
|
|
|
|
return err
|
|
}
|
|
|
|
// SaveHashRatchetKey saves a hash ratchet key
|
|
func (s *sqlitePersistence) SaveHashRatchetKey(ratchet *HashRatchetKeyCompatibility) error {
|
|
stmt, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption(group_id, key_id, key_timestamp, deprecated_key_id, key)
|
|
VALUES(?,?,?,?,?)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
keyID, err := ratchet.GetKeyID()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = stmt.Exec(ratchet.GroupID, keyID, ratchet.Timestamp, ratchet.DeprecatedKeyID(), ratchet.Key)
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *sqlitePersistence) GetHashRatchetKeyByID(keyID []byte) (*HashRatchetKeyCompatibility, error) {
|
|
ratchet := &HashRatchetKeyCompatibility{
|
|
keyID: keyID,
|
|
}
|
|
|
|
err := s.DB.QueryRow(`
|
|
SELECT group_id, key_timestamp, key
|
|
FROM hash_ratchet_encryption
|
|
WHERE key_id = ?`, keyID).Scan(&ratchet.GroupID, &ratchet.Timestamp, &ratchet.Key)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return ratchet, nil
|
|
}
|