Change handling of skipped/deleted keys & add version (#1261)

- Skipped keys

The purpose of limiting the number of skipped keys generated is to avoid a dos
attack whereby an attacker would send a large N, forcing the device to
compute all the keys between currentN..N .

Previously the logic for handling skipped keys was:

- If in the current receiving chain there are more than maxSkip keys,
throw an error

This is problematic as in long-lived session dropped/unreceived messages starts
piling up, eventually reaching the threshold (1000 dropped/unreceived
messages).

This logic has been changed to be more inline with signals spec, and now
it is:

- If N is > currentN + maxSkip, throw an error

The purpose of limiting the number of skipped keys stored is to avoid a dos
attack whereby an attacker would force us to store a large number of
keys, filling up our storage.

Previously the logic for handling old keys was:

- Once you have maxKeep ratchet steps, delete any key from
currentRatchet - maxKeep.

This, in combination with the maxSkip implementation, capped the number of stored keys to
maxSkip * maxKeep.

The logic has been changed to:

- Keep a maximum of MaxMessageKeysPerSession

and additionally we delete any key that has a sequence number <
currentSeqNum - maxKeep

- Version

We check now the version of the bundle so that when we get a bundle from
the same installationID with a higher version, we mark the previous
bundle as expired and use the new bundle the next time a message is sent
This commit is contained in:
Andrea Maria Piana 2018-11-05 20:00:04 +01:00 committed by GitHub
parent 58bd36e79e
commit ee3c05c79b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 690 additions and 161 deletions

4
Gopkg.lock generated
View File

@ -782,11 +782,11 @@
version = "v1.1"
[[projects]]
digest = "1:b80bfc6278c27edcb1946fc38bc61f4c3a1cd50978083a4a79b7c875c70ed9d8"
digest = "1:41fb72d7a71f37f1f9c766d965178636ecda21b429b1f2e3fff42cfc31279751"
name = "github.com/status-im/doubleratchet"
packages = ["."]
pruneopts = "NUT"
revision = "321788dbb6eac36f7dab04e631db139e13bb280b"
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"
[[projects]]
branch = "master"

View File

@ -160,7 +160,7 @@
[[constraint]]
name = "github.com/status-im/doubleratchet"
revision = "321788dbb6eac36f7dab04e631db139e13bb280b"
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"
[[constraint]]
name = "github.com/status-im/migrate"

View File

@ -21,6 +21,15 @@ var ErrSessionNotFound = errors.New("session not found")
// If we have no bundles, we use a constant so that the message can reach any device
const noInstallationID = "none"
// How many consecutive messages can be skipped in the receiving chain
const maxSkip = 1000
// Any message with seqNo <= currentSeq - maxKeep will be deleted
const maxKeep = 3000
// How many keys do we store in total per session
const maxMessageKeysPerSession = 2000
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol
type EncryptionService struct {
log log.Logger
@ -251,9 +260,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
keyPair,
s.persistence.GetSessionStorage(),
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}))
} else {
session, err = dr.NewWithRemoteKey(
@ -262,9 +271,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
keyPair.PubKey,
s.persistence.GetSessionStorage(),
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}))
}
@ -291,9 +300,9 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
drInfo.ID,
sessionStorage,
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}),
)
if err != nil {
@ -342,9 +351,9 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
drInfo.ID,
sessionStorage,
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}),
)
if err != nil {

View File

@ -314,6 +314,213 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
}
// Previous implementation allowed max maxSkip keys in the same receiving chain
// leading to a problem whereby dropped messages would accumulate and eventually
// we would not be able to decrypt any new message anymore.
// Here we are testing that maxSkip only applies to *consecutive* messages, not
// overall.
func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
bobText := []byte("text")
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)
// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)
// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)
// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)
// Bob sends a message
for i := 0; i < maxSkip; i++ {
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
}
// Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
s.Require().NoError(err)
// Bob sends a message
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
// Bob sends a message
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
// Alice receives the message, we should have maxSkip + 1 keys in the db, but
// we should not throw an error
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
s.Require().NoError(err)
}
// Test that an error is thrown if max skip is reached
func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
bobText := []byte("text")
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)
// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)
// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)
// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)
// Bob sends a message
for i := 0; i < maxSkip+1; i++ {
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
}
// Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err)
}
func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
bobText := []byte("text")
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)
// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)
// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)
// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)
// We create just enough messages so that the first key should be deleted
nMessages := maxMessageKeysPerSession + maxMessageKeysPerSession/maxSkip + 2
messages := make([]map[string]*DirectMessageProtocol, nMessages)
for i := 0; i < nMessages; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
messages[i] = m
// We decrypt some messages otherwise we hit maxSkip limit
if i%maxSkip == 0 {
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)
}
}
// Another message to trigger the deletion
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)
// We decrypt the first message, and it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[2])
s.Require().NoError(err)
}
func (s *EncryptionServiceTestSuite) TestMaxKeep() {
bobText := []byte("text")
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)
// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)
// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)
// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)
// We decrypt all messages but 1 & 2
messages := make([]map[string]*DirectMessageProtocol, maxKeep)
for i := 0; i < maxKeep; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
messages[i] = m
s.Require().NoError(err)
if i != 0 && i != 1 {
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)
}
}
// We decrypt the first message, and it should fail, as it should have been removed
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0])
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
s.Require().NoError(err)
}
// Alice has Bob's bundle
// Bob has Alice's bundle
// Bob sends a message to alice
@ -557,6 +764,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID)
s.Require().NoError(err)
// We set the version
bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1
err = SignBundle(bobKey, bobBundle2)
s.Require().NoError(err)

View File

@ -4,6 +4,8 @@
// 1536754952_initial_schema.up.sql
// 1539249977_update_ratchet_info.down.sql
// 1539249977_update_ratchet_info.up.sql
// 1540715431_add_version.down.sql
// 1540715431_add_version.up.sql
// static.go
// DO NOT EDIT!
@ -87,7 +89,7 @@ func _1536754952_initial_schemaDownSql() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)}
info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -107,7 +109,7 @@ func _1536754952_initial_schemaUpSql() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539252806, 0)}
info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -127,7 +129,7 @@ func _1539249977_update_ratchet_infoDownSql() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1539250187, 0)}
info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -147,7 +149,47 @@ func _1539249977_update_ratchet_infoUpSql() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1539250201, 0)}
info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
var __1540715431_add_versionDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\xf4\x09\x71\x0d\x52\x08\x71\x74\xf2\x71\x55\xc8\x4e\xad\x2c\x56\x70\x09\xf2\x0f\x50\x70\xf6\xf7\x09\xf5\xf5\x53\x28\x4e\x2d\x2e\xce\xcc\xcf\x8b\xcf\x4c\xb1\xe6\x42\x56\x08\x15\x47\x55\x0c\xd2\x1d\x9f\x9c\x5f\x9a\x57\x82\xaa\x38\xa9\x34\x2f\x25\x27\x15\x55\x6d\x59\x6a\x11\xc8\x00\x6b\x2e\x40\x00\x00\x00\xff\xff\xda\x5d\x80\x2d\x7f\x00\x00\x00")
func _1540715431_add_versionDownSqlBytes() ([]byte, error) {
return bindataRead(
__1540715431_add_versionDownSql,
"1540715431_add_version.down.sql",
)
}
func _1540715431_add_versionDownSql() (*asset, error) {
bytes, err := _1540715431_add_versionDownSqlBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "1540715431_add_version.down.sql", size: 127, mode: os.FileMode(420), modTime: time.Unix(1540989119, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
var __1540715431_add_versionUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x8c\xcd\xb1\x0e\x02\x21\x0c\xc6\xf1\xdd\xa7\xf8\x1e\xc1\xdd\x09\xa4\x67\x4c\x7a\x90\x90\x32\x93\xe8\x31\x5c\x54\x2e\x8a\x98\xf8\xf6\x06\xe3\xc2\xa2\xae\x6d\xff\xbf\x1a\x62\x12\xc2\xe0\xdd\x88\x53\x7a\x96\xcd\x4a\xb1\x90\x87\x28\xcd\xf4\x9e\x40\x19\x83\xad\xe3\x30\x5a\x94\x74\x8d\xb9\x5e\xb0\xb7\x42\x3b\xf2\xb0\x4e\x60\x03\x33\x0c\x0d\x2a\xb0\x60\xfd\xab\x2f\x65\x5e\x72\x9c\x27\x68\x76\xba\x3f\xfe\x2c\xbb\xa0\x01\xf1\xb8\xd4\x7c\xff\xfb\xe7\xa1\xe6\xe9\x9c\x3a\xe5\x91\x6e\x4d\xfe\x4a\xbc\x02\x00\x00\xff\xff\x0e\x27\x2c\x52\x09\x01\x00\x00")
func _1540715431_add_versionUpSqlBytes() ([]byte, error) {
return bindataRead(
__1540715431_add_versionUpSql,
"1540715431_add_version.up.sql",
)
}
func _1540715431_add_versionUpSql() (*asset, error) {
bytes, err := _1540715431_add_versionUpSqlBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "1540715431_add_version.up.sql", size: 265, mode: os.FileMode(420), modTime: time.Unix(1540989075, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -167,7 +209,7 @@ func staticGo() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)}
info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -228,6 +270,8 @@ var _bindata = map[string]func() (*asset, error){
"1536754952_initial_schema.up.sql": _1536754952_initial_schemaUpSql,
"1539249977_update_ratchet_info.down.sql": _1539249977_update_ratchet_infoDownSql,
"1539249977_update_ratchet_info.up.sql": _1539249977_update_ratchet_infoUpSql,
"1540715431_add_version.down.sql": _1540715431_add_versionDownSql,
"1540715431_add_version.up.sql": _1540715431_add_versionUpSql,
"static.go": staticGo,
}
@ -275,6 +319,8 @@ var _bintree = &bintree{nil, map[string]*bintree{
"1536754952_initial_schema.up.sql": &bintree{_1536754952_initial_schemaUpSql, map[string]*bintree{}},
"1539249977_update_ratchet_info.down.sql": &bintree{_1539249977_update_ratchet_infoDownSql, map[string]*bintree{}},
"1539249977_update_ratchet_info.up.sql": &bintree{_1539249977_update_ratchet_infoUpSql, map[string]*bintree{}},
"1540715431_add_version.down.sql": &bintree{_1540715431_add_versionDownSql, map[string]*bintree{}},
"1540715431_add_version.up.sql": &bintree{_1540715431_add_versionUpSql, map[string]*bintree{}},
"static.go": &bintree{staticGo, map[string]*bintree{}},
}}

View File

@ -61,8 +61,7 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa
// BuildDirectMessage marshals a 1:1 chat message given the user identity private key, the recipient's public key, and a payload
func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte, theirPublicKeys ...*ecdsa.PublicKey) (map[*ecdsa.PublicKey][]byte, error) {
response := make(map[*ecdsa.PublicKey][]byte)
publicKeys := append(theirPublicKeys, &myIdentityKey.PublicKey)
for _, publicKey := range publicKeys {
for _, publicKey := range theirPublicKeys {
// Encrypt payload
encryptionResponse, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload)
if err != nil {

View File

@ -56,7 +56,7 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
})
s.NoError(err)
marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey)
marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey, &aliceKey.PublicKey)
s.NoError(err)
s.NotNil(marshaledMsg, "It creates a message")

View File

@ -16,6 +16,9 @@ import (
"github.com/status-im/status-go/services/shhext/chat/migrations"
)
// A safe max number of rows
const maxNumberOfRows = 100000000
// SQLLitePersistence represents a persistence service tied to an SQLite database
type SQLLitePersistence struct {
db *sql.DB
@ -107,7 +110,20 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
}
for installationID, signedPreKey := range b.GetBundle().GetSignedPreKeys() {
stmt, err := tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, timestamp) VALUES(?, ?, ?, ?, ?)")
var version uint32
stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1")
if err != nil {
return err
}
defer stmt.Close()
err = stmt.QueryRow(installationID, b.GetBundle().GetIdentity()).Scan(&version)
if err != nil && err != sql.ErrNoRows {
return err
}
stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)")
if err != nil {
return err
}
@ -118,6 +134,7 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
b.GetPrivateSignedPreKey(),
signedPreKey.GetSignedPreKey(),
installationID,
version+1,
time.Now().UnixNano(),
)
if err != nil {
@ -144,15 +161,18 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, timestamp) VALUES( ?, ?, ?, ?)")
version := signedPreKeyContainer.GetVersion()
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)")
if err != nil {
return err
}
defer insertStmt.Close()
_, err = insertStmt.Exec(
b.GetIdentity(),
signedPreKey,
installationID,
version,
time.Now().UnixNano(),
)
if err != nil {
@ -160,7 +180,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
return err
}
// Mark old bundles as expired
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND signed_pre_key != ?")
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?")
if err != nil {
return err
}
@ -169,7 +189,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
_, err = updateStmt.Exec(
b.GetIdentity(),
installationID,
signedPreKey,
version,
)
if err != nil {
_ = tx.Rollback()
@ -281,7 +301,7 @@ func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundle, error) {
identity := crypto.CompressPubkey(publicKey)
stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id FROM bundles WHERE expired = 0 AND identity = ? ORDER BY timestamp DESC")
stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? ORDER BY version DESC")
if err != nil {
return nil, err
}
@ -304,16 +324,21 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundl
for rows.Next() {
var signedPreKey []byte
var installationID string
var version uint32
rowCount++
err = rows.Scan(
&signedPreKey,
&installationID,
&version,
)
if err != nil {
return nil, err
}
bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
bundle.SignedPreKeys[installationID] = &SignedPreKey{
SignedPreKey: signedPreKey,
Version: version,
}
}
@ -448,17 +473,53 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro
}
// Put stores a key with the specified public key, message number and message key
func (s *SQLLiteKeysStorage) Put(pubKey dr.Key, msgNum uint, mk dr.Key) error {
stmt, err := s.db.Prepare("insert into keys(public_key, msg_num, message_key) values(?, ?, ?)")
func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)")
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(
sessionID,
pubKey[:],
msgNum,
mk[:],
seqNum,
)
return err
}
// DeleteOldMks caps remove any key < seq_num, included
func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?")
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(
sessionID,
deleteUntil,
)
return err
}
// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion
func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)")
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(
sessionID,
// We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT
maxNumberOfRows,
maxKeysPerSession,
)
return err
@ -480,21 +541,6 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
return err
}
// DeletePk deletes the keys with the specified public key
func (s *SQLLiteKeysStorage) DeletePk(pubKey dr.Key) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ?")
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(
pubKey[:],
)
return err
}
// Count returns the count of keys with the specified public key
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
@ -512,6 +558,23 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
return count, nil
}
// CountAll returns the count of keys with the specified public key
func (s *SQLLiteKeysStorage) CountAll() (uint, error) {
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys")
if err != nil {
return 0, err
}
defer stmt.Close()
var count uint
err = stmt.QueryRow().Scan(&count)
if err != nil {
return 0, err
}
return count, nil
}
// All returns nil
func (s *SQLLiteKeysStorage) All() (map[dr.Key]map[uint]dr.Key, error) {
return nil, nil
@ -525,6 +588,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
dhsPrivate := dhs.PrivateKey()
pn := state.PN
step := state.Step
keysCount := state.KeysCount
rootChainKey := state.RootCh.CK[:]
@ -534,7 +598,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
recvChainKey := state.RecvCh.CK[:]
recvChainN := state.RecvCh.N
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
if err != nil {
return err
}
@ -552,6 +616,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
recvChainN,
pn,
step,
keysCount,
)
return err
@ -559,7 +624,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
// Load retrieves the double ratchet state for a given ID
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step FROM sessions WHERE id = ?")
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?")
if err != nil {
return nil, err
}
@ -577,6 +642,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
recvChainN uint
pn uint
step uint
keysCount uint
)
err = stmt.QueryRow(id).Scan(
@ -590,6 +656,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
&recvChainN,
&pn,
&step,
&keysCount,
)
switch err {
case sql.ErrNoRows:
@ -599,6 +666,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
state.PN = uint32(pn)
state.Step = step
state.KeysCount = keysCount
state.DHs = ecrypto.DHPair{
PrvKey: toKey(dhsPrivate),

View File

@ -10,8 +10,11 @@ import (
var (
pubKey1 = dr.Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c}
pubKey2 = dr.Key{0xec, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk1 = dr.Key{0xeb, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk2 = dr.Key{0xed, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk1 = dr.Key{0x00, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk2 = dr.Key{0x01, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk3 = dr.Key{0x02, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk4 = dr.Key{0x03, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk5 = dr.Key{0x04, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
)
func TestSQLLitePersistenceKeysStorageTestSuite(t *testing.T) {
@ -47,10 +50,84 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLiteGetMissin
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Put() {
// Act and assert.
err := s.service.Put(pubKey1, 0, mk1)
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)
}
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_DeleteOldMks() {
// Insert keys out-of-order
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 2, mk3, 20)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 3, mk4, 21)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 4, mk5, 22)
s.NoError(err)
err = s.service.DeleteOldMks([]byte("session-id"), 20)
s.NoError(err)
_, ok, err := s.service.Get(pubKey1, 0)
s.NoError(err)
s.False(ok)
_, ok, err = s.service.Get(pubKey1, 1)
s.NoError(err)
s.False(ok)
_, ok, err = s.service.Get(pubKey1, 2)
s.NoError(err)
s.False(ok)
_, ok, err = s.service.Get(pubKey1, 3)
s.NoError(err)
s.True(ok)
_, ok, err = s.service.Get(pubKey1, 4)
s.NoError(err)
s.True(ok)
}
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_TruncateMks() {
// Insert keys out-of-order
err := s.service.Put([]byte("session-id"), pubKey2, 2, mk5, 5)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey2, 0, mk3, 3)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey2, 1, mk4, 4)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)
err = s.service.TruncateMks([]byte("session-id"), 2)
s.NoError(err)
_, ok, err := s.service.Get(pubKey1, 0)
s.NoError(err)
s.False(ok)
_, ok, err = s.service.Get(pubKey1, 1)
s.NoError(err)
s.False(ok)
_, ok, err = s.service.Get(pubKey2, 0)
s.NoError(err)
s.False(ok)
_, ok, err = s.service.Get(pubKey2, 1)
s.NoError(err)
s.True(ok)
_, ok, err = s.service.Get(pubKey2, 2)
s.NoError(err)
s.True(ok)
}
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Count() {
// Act.
@ -71,12 +148,8 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Delete()
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {
// Act and assert.
err := s.service.DeletePk(pubKey1)
s.NoError(err)
// Act.
err = s.service.Put(pubKey1, 0, mk1)
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)
k, ok, err := s.service.Get(pubKey1, 0)
@ -124,30 +197,4 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {
// Assert.
s.NoError(err)
s.EqualValues(0, cnt)
// Act.
err = s.service.Put(pubKey1, 0, mk1)
s.NoError(err)
err = s.service.Put(pubKey2, 0, mk2)
s.NoError(err)
err = s.service.DeletePk(pubKey1)
s.NoError(err)
err = s.service.DeletePk(pubKey1)
s.NoError(err)
err = s.service.DeletePk(pubKey2)
s.NoError(err)
cn1, err := s.service.Count(pubKey1)
s.NoError(err)
cn2, err := s.service.Count(pubKey2)
s.NoError(err)
// Assert.
s.Empty(cn1)
s.Empty(cn2)
}

View File

@ -51,7 +51,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
s.Require().NoError(err)
actualKey, err := s.service.GetPrivateKeyBundle([]byte("non-existing"))
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualKey)
anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"))
@ -82,7 +82,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)
bundleContainer, err := NewBundleContainer(key, "1")
@ -98,12 +98,82 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
}
func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)
// Create & add initial bundle
bundleContainer, err := NewBundleContainer(key, "1")
s.Require().NoError(err)
bundle := bundleContainer.GetBundle()
err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err)
// Create & add a new bundle
bundleContainer, err = NewBundleContainer(key, "1")
s.Require().NoError(err)
bundle = bundleContainer.GetBundle()
// We set the version
bundle.GetSignedPreKeys()["1"].Version = 1
err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err)
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err)
s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
}
func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)
// Create & add initial bundle
bundleContainer, err := NewBundleContainer(key, "1")
s.Require().NoError(err)
bundle1 := bundleContainer.GetBundle()
err = s.service.AddPublicBundle(bundle1)
s.Require().NoError(err)
// Create & add a new bundle
bundleContainer, err = NewBundleContainer(key, "1")
s.Require().NoError(err)
bundle2 := bundleContainer.GetBundle()
// We set the version
bundle2.GetSignedPreKeys()["1"].Version = 1
err = s.service.AddPublicBundle(bundle2)
s.Require().NoError(err)
// Add again the initial bundle
err = s.service.AddPublicBundle(bundle1)
s.Require().NoError(err)
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err)
s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1))
s.Equal(bundle2.GetSignedPreKeys()["1"].GetSignedPreKey(), actualBundle.GetSignedPreKeys()["1"].GetSignedPreKey(), "It sets the right prekeys")
}
func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)
bundleContainer, err := NewBundleContainer(key, "1")
@ -120,8 +190,10 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
// Adding a different bundle
bundleContainer, err = NewBundleContainer(key, "1")
s.Require().NoError(err)
// We set the version
bundle = bundleContainer.GetBundle()
bundle.GetSignedPreKeys()["1"].Version = 1
err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err)
@ -139,7 +211,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() {
s.Require().NoError(err)
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)
bundleContainer, err := NewBundleContainer(key, "1")
@ -273,4 +345,3 @@ func (s *SQLLitePersistenceTestSuite) TestRatchetInfoNoBundle() {
}
// TODO: Add test for MarkBundleExpired
// TODO: Add test for AddPublicBundle checking that it expires previous bundles

View File

@ -88,7 +88,7 @@ func ConfigCliFleetEthBetaJson() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -108,7 +108,7 @@ func ConfigCliFleetEthStagingJson() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -128,7 +128,7 @@ func ConfigCliFleetEthTestJson() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -148,7 +148,7 @@ func ConfigCliLesEnabledJson() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)}
info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -168,7 +168,7 @@ func ConfigCliMailserverEnabledJson() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1538032850, 0)}
info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@ -188,7 +188,7 @@ func ConfigStatusChainGenesisJson() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)}
info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

View File

@ -0,0 +1,3 @@
ALTER TABLE keys DROP COLUMN session_id;
ALTER TABLE sessions DROP COLUMN keys_count;
ALTER TABLE bundles DROP COLUMN version;

View File

@ -0,0 +1,5 @@
DELETE FROM keys;
ALTER TABLE keys ADD COLUMN seq_num INTEGER NOT NULL DEFAULT 0;
ALTER TABLE keys ADD COLUMN session_id BLOB;
ALTER TABLE sessions ADD COLUMN keys_count INTEGER NOT NULL DEFAULT 0;
ALTER TABLE bundles ADD COLUMN version INTEGER NOT NULL DEFAULT 0;

View File

@ -1,18 +1,26 @@
package doubleratchet
import (
"bytes"
"sort"
)
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
type KeysStorage interface {
// Get returns a message key by the given key and message number.
Get(k Key, msgNum uint) (mk Key, ok bool, err error)
// Put saves the given mk under the specified key and msgNum.
Put(k Key, msgNum uint, mk Key) error
Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error
// DeleteMk ensures there's no message key under the specified key and msgNum.
DeleteMk(k Key, msgNum uint) error
// DeletePk ensures there's no message keys under the specified key.
DeletePk(k Key) error
// DeleteOldMKeys deletes old message keys for a session.
DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error
// TruncateMks truncates the number of keys to maxKeys.
TruncateMks(sessionID []byte, maxKeys int) error
// Count returns number of message keys stored under the specified key.
Count(k Key) (uint, error)
@ -23,10 +31,10 @@ type KeysStorage interface {
// KeysStorageInMemory is an in-memory message keys storage.
type KeysStorageInMemory struct {
keys map[Key]map[uint]Key
keys map[Key]map[uint]InMemoryKey
}
// See KeysStorage.
// Get returns a message key by the given key and message number.
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
if s.keys == nil {
return Key{}, false, nil
@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
if !ok {
return Key{}, false, nil
}
return mk, true, nil
return mk.messageKey, true, nil
}
// See KeysStorage.
func (s *KeysStorageInMemory) Put(pubKey Key, msgNum uint, mk Key) error {
type InMemoryKey struct {
messageKey Key
seqNum uint
sessionID []byte
}
// Put saves the given mk under the specified key and msgNum.
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
if s.keys == nil {
s.keys = make(map[Key]map[uint]Key)
s.keys = make(map[Key]map[uint]InMemoryKey)
}
if _, ok := s.keys[pubKey]; !ok {
s.keys[pubKey] = make(map[uint]Key)
s.keys[pubKey] = make(map[uint]InMemoryKey)
}
s.keys[pubKey][msgNum] = InMemoryKey{
sessionID: sessionID,
messageKey: mk,
seqNum: seqNum,
}
s.keys[pubKey][msgNum] = mk
return nil
}
// See KeysStorage.
// DeleteMk ensures there's no message key under the specified key and msgNum.
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
if s.keys == nil {
return nil
@ -72,19 +90,58 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
return nil
}
// See KeysStorage.
func (s *KeysStorageInMemory) DeletePk(pubKey Key) error {
if s.keys == nil {
// TruncateMks truncates the number of keys to maxKeys.
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
var seqNos []uint
// Collect all seq numbers
for _, keys := range s.keys {
for _, inMemoryKey := range keys {
if bytes.Equal(inMemoryKey.sessionID, sessionID) {
seqNos = append(seqNos, inMemoryKey.seqNum)
}
}
}
// Nothing to do if we haven't reached the limit
if len(seqNos) <= maxKeys {
return nil
}
if _, ok := s.keys[pubKey]; !ok {
return nil
// Take the sequence numbers we care about
sort.Slice(seqNos, func(i, j int) bool { return seqNos[i] < seqNos[j] })
toDeleteSlice := seqNos[:len(seqNos)-maxKeys]
// Put in map for easier lookup
toDelete := make(map[uint]bool)
for _, seqNo := range toDeleteSlice {
toDelete[seqNo] = true
}
delete(s.keys, pubKey)
for pubKey, keys := range s.keys {
for i, inMemoryKey := range keys {
if toDelete[inMemoryKey.seqNum] && bytes.Equal(inMemoryKey.sessionID, sessionID) {
delete(s.keys[pubKey], i)
}
}
}
return nil
}
// See KeysStorage.
// DeleteOldMKeys deletes old message keys for a session.
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
for pubKey, keys := range s.keys {
for i, inMemoryKey := range keys {
if inMemoryKey.seqNum <= deleteUntilSeqKey && bytes.Equal(inMemoryKey.sessionID, sessionID) {
delete(s.keys[pubKey], i)
}
}
}
return nil
}
// Count returns number of message keys stored under the specified key.
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
if s.keys == nil {
return 0, nil
@ -92,7 +149,16 @@ func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
return uint(len(s.keys[pubKey])), nil
}
// See KeysStorage.
// All returns all the keys
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
return s.keys, nil
response := make(map[Key]map[uint]Key)
for pubKey, keys := range s.keys {
response[pubKey] = make(map[uint]Key)
for n, key := range keys {
response[pubKey][n] = key.messageKey
}
}
return response, nil
}

View File

@ -17,7 +17,7 @@ func WithMaxSkip(n int) option {
}
}
// WithMaxKeep specifies the maximum number of ratchet steps before a message is deleted.
// WithMaxKeep specifies how long we keep message keys, counted in number of messages received
// nolint: golint
func WithMaxKeep(n int) option {
return func(s *State) error {
@ -29,6 +29,18 @@ func WithMaxKeep(n int) option {
}
}
// WithMaxMessageKeysPerSession specifies the maximum number of message keys per session
// nolint: golint
func WithMaxMessageKeysPerSession(n int) option {
return func(s *State) error {
if n < 0 {
return fmt.Errorf("n must be non-negative")
}
s.MaxMessageKeysPerSession = n
return nil
}
}
// WithKeysStorage replaces the default keys storage with the specified.
// nolint: golint
func WithKeysStorage(ks KeysStorage) option {

View File

@ -130,7 +130,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
)
// Is there a new ratchet key?
isDHStepped := false
if m.Header.DH != sc.DHr {
if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
@ -138,7 +137,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
if err = sc.dhRatchet(m.Header); err != nil {
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
}
isDHStepped = true
}
// After all, update the current chain.
@ -151,17 +149,13 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
return nil, fmt.Errorf("can't decrypt: %s", err)
}
// Apply changes.
if err := s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
return nil, err
}
// Increment the number of keys
sc.KeysCount++
if isDHStepped {
err = s.deleteSkippedKeys(s.DHr)
if err != nil {
// Apply changes.
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
return nil, err
}
}
// Store state
if err := s.store(); err != nil {

View File

@ -103,12 +103,9 @@ func (s *sessionHE) RatchetDecrypt(m MessageHE, ad []byte) ([]byte, error) {
return nil, fmt.Errorf("can't decrypt: %s", err)
}
if err = s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
if err = s.applyChanges(sc, []byte("FIXME"), append(skippedKeys1, skippedKeys2...)); err != nil {
return nil, fmt.Errorf("failed to apply changes: %s", err)
}
if step {
_ = s.deleteSkippedKeys(s.HKr)
}
return plaintext, nil
}

View File

@ -42,14 +42,18 @@ type State struct {
// Sending header key and next header key. Only used for header encryption.
HKs, NHKs Key
// Number of ratchet steps after which all skipped message keys for that key will be deleted.
// How long we keep messages keys, counted in number of messages received,
// for example if MaxKeep is 5 we only keep the last 5 messages keys, deleting everything n - 5.
MaxKeep uint
// Max number of message keys per session, older keys will be deleted in FIFO fashion
MaxMessageKeysPerSession int
// The number of the current ratchet step.
Step uint
// Which key for the receiving chain was used at the specified step.
DeleteKeys map[uint]Key
// KeysCount the number of keys generated for decrypting
KeysCount uint
}
func DefaultState(sharedKey Key) State {
@ -65,8 +69,9 @@ func DefaultState(sharedKey Key) State {
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
MkSkipped: &KeysStorageInMemory{},
MaxSkip: 1000,
MaxKeep: 100,
DeleteKeys: make(map[uint]Key),
MaxMessageKeysPerSession: 2000,
MaxKeep: 2000,
KeysCount: 0,
}
}
@ -112,6 +117,7 @@ type skippedKey struct {
key Key
nr uint
mk Key
seq uint
}
// skipMessageKeys skips message keys in the current receiving chain.
@ -119,14 +125,11 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
if until < uint(s.RecvCh.N) {
return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
}
nSkipped, err := s.MkSkipped.Count(key)
if err != nil {
return nil, err
}
if until-uint(s.RecvCh.N)+nSkipped > s.MaxSkip {
if uint(s.RecvCh.N)+s.MaxSkip < until {
return nil, fmt.Errorf("too many messages")
}
skipped := []skippedKey{}
for uint(s.RecvCh.N) < until {
mk := s.RecvCh.step()
@ -134,32 +137,31 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
key: key,
nr: uint(s.RecvCh.N - 1),
mk: mk,
seq: s.KeysCount,
})
// Increment key count
s.KeysCount++
}
return skipped, nil
}
func (s *State) applyChanges(sc State, skipped []skippedKey) error {
func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error {
*s = sc
for _, skipped := range skipped {
if err := s.MkSkipped.Put(skipped.key, skipped.nr, skipped.mk); err != nil {
if err := s.MkSkipped.Put(sessionID, skipped.key, skipped.nr, skipped.mk, skipped.seq); err != nil {
return err
}
}
if err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession); err != nil {
return err
}
if s.KeysCount >= s.MaxKeep {
if err := s.MkSkipped.DeleteOldMks(sessionID, s.KeysCount-s.MaxKeep); err != nil {
return err
}
}
return nil
}
func (s *State) deleteSkippedKeys(key Key) error {
s.DeleteKeys[s.Step] = key
s.Step++
if hk, ok := s.DeleteKeys[s.Step-s.MaxKeep]; ok {
if err := s.MkSkipped.DeletePk(hk); err != nil {
return err
}
delete(s.DeleteKeys, s.Step-s.MaxKeep)
}
return nil
}