Change handling of skipped/deleted keys & add version (#1261)
- Skipped keys

  The purpose of limiting the number of skipped keys generated is to avoid a DoS attack whereby an attacker sends a large N, forcing the device to compute all the keys between currentN and N.

  Previously the logic for handling skipped keys was:

  - If the current receiving chain has more than maxSkip keys, throw an error.

  This is problematic because in long-lived sessions dropped/unreceived messages pile up, eventually reaching the threshold (1000 dropped/unreceived messages).

  The logic has been changed to be more in line with Signal's spec and is now:

  - If N > currentN + maxSkip, throw an error.

  The purpose of limiting the number of skipped keys stored is to avoid a DoS attack whereby an attacker forces us to store a large number of keys, filling up our storage.

  Previously the logic for handling old keys was:

  - Once there are maxKeep ratchet steps, delete any key from ratchet step currentRatchet - maxKeep.

  This, in combination with the maxSkip implementation, capped the number of stored keys at maxSkip * maxKeep.

  The logic has been changed to:

  - Keep a maximum of MaxMessageKeysPerSession keys, and additionally delete any key with a sequence number < currentSeqNum - maxKeep.

- Version

  We now check the version of the bundle, so that when we receive a bundle from the same installationID with a higher version, we mark the previous bundle as expired and use the new bundle the next time a message is sent.
parent 58bd36e79e
commit ee3c05c79b
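The following is a minimal, self-contained Go sketch of the two policies described above, intended only to make the change easier to follow. The constants mirror the ones introduced in this commit (maxSkip, maxKeep, maxMessageKeysPerSession); the helper functions, their names and the example values are illustrative assumptions and are not part of the status-go or doubleratchet code.

package main

import (
	"errors"
	"fmt"
)

const (
	maxSkip                  = 1000 // max gap of *consecutive* skipped keys in the receiving chain
	maxKeep                  = 3000 // keys older than currentSeqNum-maxKeep are deleted
	maxMessageKeysPerSession = 2000 // hard cap on stored keys per session
)

// canSkipTo illustrates the new rule: only the gap between the incoming
// message number and the current chain position matters, not the total
// number of keys previously skipped in the chain.
func canSkipTo(currentN, incomingN uint) error {
	if incomingN > currentN+maxSkip {
		return errors.New("too many messages")
	}
	return nil
}

// pruneKeys illustrates the new retention policy over stored key sequence
// numbers (assumed to be in ascending order): drop everything older than
// currentSeqNum-maxKeep, then truncate to maxMessageKeysPerSession,
// oldest first.
func pruneKeys(seqNums []uint, currentSeqNum uint) []uint {
	var kept []uint
	for _, n := range seqNums {
		if currentSeqNum < maxKeep || n >= currentSeqNum-maxKeep {
			kept = append(kept, n)
		}
	}
	if len(kept) > maxMessageKeysPerSession {
		kept = kept[len(kept)-maxMessageKeysPerSession:]
	}
	return kept
}

func main() {
	fmt.Println(canSkipTo(10, 10+maxSkip))                 // <nil>: gap equals maxSkip, still allowed
	fmt.Println(canSkipTo(10, 11+maxSkip))                 // too many messages
	fmt.Println(pruneKeys([]uint{1, 2, 3000, 3001}, 6001)) // [3001]
}

In the actual change the skip check lives in doubleratchet's skipMessageKeys, and the pruning is done by TruncateMks and DeleteOldMks in the keys storage, driven by the per-key sequence number added in this commit.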
@@ -782,11 +782,11 @@
version = "v1.1"

[[projects]]
digest = "1:b80bfc6278c27edcb1946fc38bc61f4c3a1cd50978083a4a79b7c875c70ed9d8"
digest = "1:41fb72d7a71f37f1f9c766d965178636ecda21b429b1f2e3fff42cfc31279751"
name = "github.com/status-im/doubleratchet"
packages = ["."]
pruneopts = "NUT"
revision = "321788dbb6eac36f7dab04e631db139e13bb280b"
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"

[[projects]]
branch = "master"
@@ -160,7 +160,7 @@

[[constraint]]
name = "github.com/status-im/doubleratchet"
revision = "321788dbb6eac36f7dab04e631db139e13bb280b"
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"

[[constraint]]
name = "github.com/status-im/migrate"
@@ -21,6 +21,15 @@ var ErrSessionNotFound = errors.New("session not found")
// If we have no bundles, we use a constant so that the message can reach any device
const noInstallationID = "none"

// How many consecutive messages can be skipped in the receiving chain
const maxSkip = 1000

// Any message with seqNo <= currentSeq - maxKeep will be deleted
const maxKeep = 3000

// How many keys do we store in total per session
const maxMessageKeysPerSession = 2000

// EncryptionService defines a service that is responsible for the encryption aspect of the protocol
type EncryptionService struct {
log log.Logger
@@ -251,9 +260,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
keyPair,
s.persistence.GetSessionStorage(),
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}))
} else {
session, err = dr.NewWithRemoteKey(

@@ -262,9 +271,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
keyPair.PubKey,
s.persistence.GetSessionStorage(),
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}))
}

@@ -291,9 +300,9 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
drInfo.ID,
sessionStorage,
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}),
)
if err != nil {

@@ -342,9 +351,9 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
drInfo.ID,
sessionStorage,
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
// TODO: Temporarily increase to a high number, until
// we make sure it's a sliding window rather than dropping
dr.WithMaxSkip(10000),
dr.WithMaxSkip(maxSkip),
dr.WithMaxKeep(maxKeep),
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
dr.WithCrypto(crypto.EthereumCrypto{}),
)
if err != nil {
@@ -314,6 +314,213 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
}

// Previous implementation allowed max maxSkip keys in the same receiving chain
// leading to a problem whereby dropped messages would accumulate and eventually
// we would not be able to decrypt any new message anymore.
// Here we are testing that maxSkip only applies to *consecutive* messages, not
// overall.
func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
bobText := []byte("text")

bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)

aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)

// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)

// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)

// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)

// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)

// Bob sends a message

for i := 0; i < maxSkip; i++ {
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
}

// Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)

// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
s.Require().NoError(err)

// Bob sends a message
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)

// Bob sends a message
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)

// Alice receives the message, we should have maxSkip + 1 keys in the db, but
// we should not throw an error
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
s.Require().NoError(err)
}

// Test that an error is thrown if max skip is reached
func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
bobText := []byte("text")

bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)

aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)

// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)

// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)

// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)

// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)

// Bob sends a message

for i := 0; i < maxSkip+1; i++ {
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
}

// Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)

// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err)
}

func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
bobText := []byte("text")

bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)

aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)

// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)

// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)

// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)

// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)

// We create just enough messages so that the first key should be deleted

nMessages := maxMessageKeysPerSession + maxMessageKeysPerSession/maxSkip + 2
messages := make([]map[string]*DirectMessageProtocol, nMessages)
for i := 0; i < nMessages; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)

messages[i] = m

// We decrypt some messages otherwise we hit maxSkip limit
if i%maxSkip == 0 {
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)
}

}

// Another message to trigger the deletion
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)

// We decrypt the first message, and it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)

// We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[2])
s.Require().NoError(err)
}

func (s *EncryptionServiceTestSuite) TestMaxKeep() {
bobText := []byte("text")

bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)

aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)

// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)

// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)

// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)

// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)

// We decrypt all messages but 1 & 2
messages := make([]map[string]*DirectMessageProtocol, maxKeep)
for i := 0; i < maxKeep; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
messages[i] = m
s.Require().NoError(err)

if i != 0 && i != 1 {
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)
}

}

// We decrypt the first message, and it should fail, as it should have been removed
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0])
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)

// We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
s.Require().NoError(err)
}

// Alice has Bob's bundle
// Bob has Alice's bundle
// Bob sends a message to alice
@@ -557,6 +764,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {

bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID)
s.Require().NoError(err)
// We set the version
bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1

err = SignBundle(bobKey, bobBundle2)
s.Require().NoError(err)
@@ -4,6 +4,8 @@
// 1536754952_initial_schema.up.sql
// 1539249977_update_ratchet_info.down.sql
// 1539249977_update_ratchet_info.up.sql
// 1540715431_add_version.down.sql
// 1540715431_add_version.up.sql
// static.go
// DO NOT EDIT!

@@ -87,7 +89,7 @@ func _1536754952_initial_schemaDownSql() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)}
info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -107,7 +109,7 @@ func _1536754952_initial_schemaUpSql() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539252806, 0)}
info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -127,7 +129,7 @@ func _1539249977_update_ratchet_infoDownSql() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1539250187, 0)}
info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -147,7 +149,47 @@ func _1539249977_update_ratchet_infoUpSql() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1539250201, 0)}
info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

var __1540715431_add_versionDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\xf4\x09\x71\x0d\x52\x08\x71\x74\xf2\x71\x55\xc8\x4e\xad\x2c\x56\x70\x09\xf2\x0f\x50\x70\xf6\xf7\x09\xf5\xf5\x53\x28\x4e\x2d\x2e\xce\xcc\xcf\x8b\xcf\x4c\xb1\xe6\x42\x56\x08\x15\x47\x55\x0c\xd2\x1d\x9f\x9c\x5f\x9a\x57\x82\xaa\x38\xa9\x34\x2f\x25\x27\x15\x55\x6d\x59\x6a\x11\xc8\x00\x6b\x2e\x40\x00\x00\x00\xff\xff\xda\x5d\x80\x2d\x7f\x00\x00\x00")

func _1540715431_add_versionDownSqlBytes() ([]byte, error) {
return bindataRead(
__1540715431_add_versionDownSql,
"1540715431_add_version.down.sql",
)
}

func _1540715431_add_versionDownSql() (*asset, error) {
bytes, err := _1540715431_add_versionDownSqlBytes()
if err != nil {
return nil, err
}

info := bindataFileInfo{name: "1540715431_add_version.down.sql", size: 127, mode: os.FileMode(420), modTime: time.Unix(1540989119, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

var __1540715431_add_versionUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x8c\xcd\xb1\x0e\x02\x21\x0c\xc6\xf1\xdd\xa7\xf8\x1e\xc1\xdd\x09\xa4\x67\x4c\x7a\x90\x90\x32\x93\xe8\x31\x5c\x54\x2e\x8a\x98\xf8\xf6\x06\xe3\xc2\xa2\xae\x6d\xff\xbf\x1a\x62\x12\xc2\xe0\xdd\x88\x53\x7a\x96\xcd\x4a\xb1\x90\x87\x28\xcd\xf4\x9e\x40\x19\x83\xad\xe3\x30\x5a\x94\x74\x8d\xb9\x5e\xb0\xb7\x42\x3b\xf2\xb0\x4e\x60\x03\x33\x0c\x0d\x2a\xb0\x60\xfd\xab\x2f\x65\x5e\x72\x9c\x27\x68\x76\xba\x3f\xfe\x2c\xbb\xa0\x01\xf1\xb8\xd4\x7c\xff\xfb\xe7\xa1\xe6\xe9\x9c\x3a\xe5\x91\x6e\x4d\xfe\x4a\xbc\x02\x00\x00\xff\xff\x0e\x27\x2c\x52\x09\x01\x00\x00")

func _1540715431_add_versionUpSqlBytes() ([]byte, error) {
return bindataRead(
__1540715431_add_versionUpSql,
"1540715431_add_version.up.sql",
)
}

func _1540715431_add_versionUpSql() (*asset, error) {
bytes, err := _1540715431_add_versionUpSqlBytes()
if err != nil {
return nil, err
}

info := bindataFileInfo{name: "1540715431_add_version.up.sql", size: 265, mode: os.FileMode(420), modTime: time.Unix(1540989075, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -167,7 +209,7 @@ func staticGo() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)}
info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@@ -228,6 +270,8 @@ var _bindata = map[string]func() (*asset, error){
"1536754952_initial_schema.up.sql": _1536754952_initial_schemaUpSql,
"1539249977_update_ratchet_info.down.sql": _1539249977_update_ratchet_infoDownSql,
"1539249977_update_ratchet_info.up.sql": _1539249977_update_ratchet_infoUpSql,
"1540715431_add_version.down.sql": _1540715431_add_versionDownSql,
"1540715431_add_version.up.sql": _1540715431_add_versionUpSql,
"static.go": staticGo,
}

@@ -275,6 +319,8 @@ var _bintree = &bintree{nil, map[string]*bintree{
"1536754952_initial_schema.up.sql": &bintree{_1536754952_initial_schemaUpSql, map[string]*bintree{}},
"1539249977_update_ratchet_info.down.sql": &bintree{_1539249977_update_ratchet_infoDownSql, map[string]*bintree{}},
"1539249977_update_ratchet_info.up.sql": &bintree{_1539249977_update_ratchet_infoUpSql, map[string]*bintree{}},
"1540715431_add_version.down.sql": &bintree{_1540715431_add_versionDownSql, map[string]*bintree{}},
"1540715431_add_version.up.sql": &bintree{_1540715431_add_versionUpSql, map[string]*bintree{}},
"static.go": &bintree{staticGo, map[string]*bintree{}},
}}
@@ -61,8 +61,7 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa
// BuildDirectMessage marshals a 1:1 chat message given the user identity private key, the recipient's public key, and a payload
func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte, theirPublicKeys ...*ecdsa.PublicKey) (map[*ecdsa.PublicKey][]byte, error) {
response := make(map[*ecdsa.PublicKey][]byte)
publicKeys := append(theirPublicKeys, &myIdentityKey.PublicKey)
for _, publicKey := range publicKeys {
for _, publicKey := range theirPublicKeys {
// Encrypt payload
encryptionResponse, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload)
if err != nil {

@@ -56,7 +56,7 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
})
s.NoError(err)

marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey)
marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey, &aliceKey.PublicKey)

s.NoError(err)
s.NotNil(marshaledMsg, "It creates a message")
@@ -16,6 +16,9 @@ import (
"github.com/status-im/status-go/services/shhext/chat/migrations"
)

// A safe max number of rows
const maxNumberOfRows = 100000000

// SQLLitePersistence represents a persistence service tied to an SQLite database
type SQLLitePersistence struct {
db *sql.DB
@@ -107,7 +110,20 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
}

for installationID, signedPreKey := range b.GetBundle().GetSignedPreKeys() {
stmt, err := tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, timestamp) VALUES(?, ?, ?, ?, ?)")
var version uint32
stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1")
if err != nil {
return err
}

defer stmt.Close()

err = stmt.QueryRow(installationID, b.GetBundle().GetIdentity()).Scan(&version)
if err != nil && err != sql.ErrNoRows {
return err
}

stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)")
if err != nil {
return err
}

@@ -118,6 +134,7 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
b.GetPrivateSignedPreKey(),
signedPreKey.GetSignedPreKey(),
installationID,
version+1,
time.Now().UnixNano(),
)
if err != nil {

@@ -144,15 +161,18 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {

for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, timestamp) VALUES( ?, ?, ?, ?)")
version := signedPreKeyContainer.GetVersion()
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)")
if err != nil {
return err
}
defer insertStmt.Close()

_, err = insertStmt.Exec(
b.GetIdentity(),
signedPreKey,
installationID,
version,
time.Now().UnixNano(),
)
if err != nil {

@@ -160,7 +180,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
return err
}
// Mark old bundles as expired
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND signed_pre_key != ?")
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?")
if err != nil {
return err
}

@@ -169,7 +189,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
_, err = updateStmt.Exec(
b.GetIdentity(),
installationID,
signedPreKey,
version,
)
if err != nil {
_ = tx.Rollback()
@@ -281,7 +301,7 @@ func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundle, error) {

identity := crypto.CompressPubkey(publicKey)
stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id FROM bundles WHERE expired = 0 AND identity = ? ORDER BY timestamp DESC")
stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? ORDER BY version DESC")
if err != nil {
return nil, err
}

@@ -304,16 +324,21 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundl
for rows.Next() {
var signedPreKey []byte
var installationID string
var version uint32
rowCount++
err = rows.Scan(
&signedPreKey,
&installationID,
&version,
)
if err != nil {
return nil, err
}

bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
bundle.SignedPreKeys[installationID] = &SignedPreKey{
SignedPreKey: signedPreKey,
Version: version,
}

}
@@ -448,17 +473,53 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro
}

// Put stores a key with the specified public key, message number and message key
func (s *SQLLiteKeysStorage) Put(pubKey dr.Key, msgNum uint, mk dr.Key) error {
stmt, err := s.db.Prepare("insert into keys(public_key, msg_num, message_key) values(?, ?, ?)")
func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)")
if err != nil {
return err
}
defer stmt.Close()

_, err = stmt.Exec(
sessionID,
pubKey[:],
msgNum,
mk[:],
seqNum,
)

return err
}

// DeleteOldMks caps remove any key < seq_num, included
func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?")
if err != nil {
return err
}
defer stmt.Close()

_, err = stmt.Exec(
sessionID,
deleteUntil,
)

return err
}

// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion
func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)")
if err != nil {
return err
}
defer stmt.Close()

_, err = stmt.Exec(
sessionID,
// We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT
maxNumberOfRows,
maxKeysPerSession,
)

return err

@@ -480,21 +541,6 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
return err
}

// DeletePk deletes the keys with the specified public key
func (s *SQLLiteKeysStorage) DeletePk(pubKey dr.Key) error {
stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ?")
if err != nil {
return err
}
defer stmt.Close()

_, err = stmt.Exec(
pubKey[:],
)

return err
}

// Count returns the count of keys with the specified public key
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
@@ -512,6 +558,23 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
return count, nil
}

// CountAll returns the count of keys with the specified public key
func (s *SQLLiteKeysStorage) CountAll() (uint, error) {
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys")
if err != nil {
return 0, err
}
defer stmt.Close()

var count uint
err = stmt.QueryRow().Scan(&count)
if err != nil {
return 0, err
}

return count, nil
}

// All returns nil
func (s *SQLLiteKeysStorage) All() (map[dr.Key]map[uint]dr.Key, error) {
return nil, nil
@@ -525,6 +588,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
dhsPrivate := dhs.PrivateKey()
pn := state.PN
step := state.Step
keysCount := state.KeysCount

rootChainKey := state.RootCh.CK[:]

@@ -534,7 +598,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
recvChainKey := state.RecvCh.CK[:]
recvChainN := state.RecvCh.N

stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
if err != nil {
return err
}

@@ -552,6 +616,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
recvChainN,
pn,
step,
keysCount,
)

return err

@@ -559,7 +624,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {

// Load retrieves the double ratchet state for a given ID
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step FROM sessions WHERE id = ?")
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?")
if err != nil {
return nil, err
}

@@ -577,6 +642,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
recvChainN uint
pn uint
step uint
keysCount uint
)

err = stmt.QueryRow(id).Scan(

@@ -590,6 +656,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
&recvChainN,
&pn,
&step,
&keysCount,
)
switch err {
case sql.ErrNoRows:

@@ -599,6 +666,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {

state.PN = uint32(pn)
state.Step = step
state.KeysCount = keysCount

state.DHs = ecrypto.DHPair{
PrvKey: toKey(dhsPrivate),
@@ -10,8 +10,11 @@ import (
var (
pubKey1 = dr.Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c}
pubKey2 = dr.Key{0xec, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk1 = dr.Key{0xeb, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk2 = dr.Key{0xed, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk1 = dr.Key{0x00, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk2 = dr.Key{0x01, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk3 = dr.Key{0x02, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk4 = dr.Key{0x03, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
mk5 = dr.Key{0x04, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
)

func TestSQLLitePersistenceKeysStorageTestSuite(t *testing.T) {
@@ -47,10 +50,84 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLiteGetMissin

func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Put() {
// Act and assert.
err := s.service.Put(pubKey1, 0, mk1)
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)
}

func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_DeleteOldMks() {
// Insert keys out-of-order
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 2, mk3, 20)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 3, mk4, 21)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 4, mk5, 22)
s.NoError(err)

err = s.service.DeleteOldMks([]byte("session-id"), 20)
s.NoError(err)

_, ok, err := s.service.Get(pubKey1, 0)
s.NoError(err)
s.False(ok)

_, ok, err = s.service.Get(pubKey1, 1)
s.NoError(err)
s.False(ok)

_, ok, err = s.service.Get(pubKey1, 2)
s.NoError(err)
s.False(ok)

_, ok, err = s.service.Get(pubKey1, 3)
s.NoError(err)
s.True(ok)

_, ok, err = s.service.Get(pubKey1, 4)
s.NoError(err)
s.True(ok)
}

func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_TruncateMks() {
// Insert keys out-of-order
err := s.service.Put([]byte("session-id"), pubKey2, 2, mk5, 5)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey2, 0, mk3, 3)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey2, 1, mk4, 4)
s.NoError(err)
err = s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)

err = s.service.TruncateMks([]byte("session-id"), 2)
s.NoError(err)

_, ok, err := s.service.Get(pubKey1, 0)
s.NoError(err)
s.False(ok)

_, ok, err = s.service.Get(pubKey1, 1)
s.NoError(err)
s.False(ok)

_, ok, err = s.service.Get(pubKey2, 0)
s.NoError(err)
s.False(ok)

_, ok, err = s.service.Get(pubKey2, 1)
s.NoError(err)
s.True(ok)

_, ok, err = s.service.Get(pubKey2, 2)
s.NoError(err)
s.True(ok)
}

func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Count() {

// Act.

@@ -71,12 +148,8 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Delete()

func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {

// Act and assert.
err := s.service.DeletePk(pubKey1)
s.NoError(err)

// Act.
err = s.service.Put(pubKey1, 0, mk1)
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
s.NoError(err)

k, ok, err := s.service.Get(pubKey1, 0)
@@ -124,30 +197,4 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {
// Assert.
s.NoError(err)
s.EqualValues(0, cnt)

// Act.
err = s.service.Put(pubKey1, 0, mk1)
s.NoError(err)

err = s.service.Put(pubKey2, 0, mk2)
s.NoError(err)

err = s.service.DeletePk(pubKey1)
s.NoError(err)

err = s.service.DeletePk(pubKey1)
s.NoError(err)

err = s.service.DeletePk(pubKey2)
s.NoError(err)

cn1, err := s.service.Count(pubKey1)
s.NoError(err)

cn2, err := s.service.Count(pubKey2)
s.NoError(err)

// Assert.
s.Empty(cn1)
s.Empty(cn2)
}
@@ -51,7 +51,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
s.Require().NoError(err)

actualKey, err := s.service.GetPrivateKeyBundle([]byte("non-existing"))
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualKey)

anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"))

@@ -82,7 +82,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
s.Require().NoError(err)

actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)

bundleContainer, err := NewBundleContainer(key, "1")
@@ -98,12 +98,82 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
}

func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)

actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)

// Create & add initial bundle
bundleContainer, err := NewBundleContainer(key, "1")
s.Require().NoError(err)

bundle := bundleContainer.GetBundle()
err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err)

// Create & add a new bundle
bundleContainer, err = NewBundleContainer(key, "1")
s.Require().NoError(err)
bundle = bundleContainer.GetBundle()
// We set the version
bundle.GetSignedPreKeys()["1"].Version = 1

err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err)

actualBundle, err = s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err)
s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
}

func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)

actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)

// Create & add initial bundle
bundleContainer, err := NewBundleContainer(key, "1")
s.Require().NoError(err)

bundle1 := bundleContainer.GetBundle()
err = s.service.AddPublicBundle(bundle1)
s.Require().NoError(err)

// Create & add a new bundle
bundleContainer, err = NewBundleContainer(key, "1")
s.Require().NoError(err)

bundle2 := bundleContainer.GetBundle()
// We set the version
bundle2.GetSignedPreKeys()["1"].Version = 1

err = s.service.AddPublicBundle(bundle2)
s.Require().NoError(err)

// Add again the initial bundle
err = s.service.AddPublicBundle(bundle1)
s.Require().NoError(err)

actualBundle, err = s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err)
s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1))
s.Equal(bundle2.GetSignedPreKeys()["1"].GetSignedPreKey(), actualBundle.GetSignedPreKeys()["1"].GetSignedPreKey(), "It sets the right prekeys")
}

func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)

actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)

bundleContainer, err := NewBundleContainer(key, "1")
@@ -120,8 +190,10 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
// Adding a different bundle
bundleContainer, err = NewBundleContainer(key, "1")
s.Require().NoError(err)

// We set the version
bundle = bundleContainer.GetBundle()
bundle.GetSignedPreKeys()["1"].Version = 1

err = s.service.AddPublicBundle(bundle)
s.Require().NoError(err)

@@ -139,7 +211,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() {
s.Require().NoError(err)

actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
s.Require().NoError(err, "It does not return an error if the bundle is not there")
s.Require().NoError(err, "Error was not returned even though bundle is not there")
s.Nil(actualBundle)

bundleContainer, err := NewBundleContainer(key, "1")

@@ -273,4 +345,3 @@ func (s *SQLLitePersistenceTestSuite) TestRatchetInfoNoBundle() {
}

// TODO: Add test for MarkBundleExpired
// TODO: Add test for AddPublicBundle checking that it expires previous bundles
@@ -88,7 +88,7 @@ func ConfigCliFleetEthBetaJson() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -108,7 +108,7 @@ func ConfigCliFleetEthStagingJson() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -128,7 +128,7 @@ func ConfigCliFleetEthTestJson() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -148,7 +148,7 @@ func ConfigCliLesEnabledJson() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)}
info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -168,7 +168,7 @@ func ConfigCliMailserverEnabledJson() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1538032850, 0)}
info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}

@@ -188,7 +188,7 @@ func ConfigStatusChainGenesisJson() (*asset, error) {
return nil, err
}

info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)}
info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
@@ -0,0 +1,3 @@
ALTER TABLE keys DROP COLUMN session_id;
ALTER TABLE sessions DROP COLUMN keys_count;
ALTER TABLE bundles DROP COLUMN version;

@@ -0,0 +1,5 @@
DELETE FROM keys;
ALTER TABLE keys ADD COLUMN seq_num INTEGER NOT NULL DEFAULT 0;
ALTER TABLE keys ADD COLUMN session_id BLOB;
ALTER TABLE sessions ADD COLUMN keys_count INTEGER NOT NULL DEFAULT 0;
ALTER TABLE bundles ADD COLUMN version INTEGER NOT NULL DEFAULT 0;
@@ -1,18 +1,26 @@
package doubleratchet

import (
"bytes"
"sort"
)

// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
type KeysStorage interface {
// Get returns a message key by the given key and message number.
Get(k Key, msgNum uint) (mk Key, ok bool, err error)

// Put saves the given mk under the specified key and msgNum.
Put(k Key, msgNum uint, mk Key) error
Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error

// DeleteMk ensures there's no message key under the specified key and msgNum.
DeleteMk(k Key, msgNum uint) error

// DeletePk ensures there's no message keys under the specified key.
DeletePk(k Key) error
// DeleteOldMKeys deletes old message keys for a session.
DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error

// TruncateMks truncates the number of keys to maxKeys.
TruncateMks(sessionID []byte, maxKeys int) error

// Count returns number of message keys stored under the specified key.
Count(k Key) (uint, error)

@@ -23,10 +31,10 @@ type KeysStorage interface {

// KeysStorageInMemory is an in-memory message keys storage.
type KeysStorageInMemory struct {
keys map[Key]map[uint]Key
keys map[Key]map[uint]InMemoryKey
}

// See KeysStorage.
// Get returns a message key by the given key and message number.
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
if s.keys == nil {
return Key{}, false, nil
@@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
if !ok {
return Key{}, false, nil
}
return mk, true, nil
return mk.messageKey, true, nil
}

// See KeysStorage.
func (s *KeysStorageInMemory) Put(pubKey Key, msgNum uint, mk Key) error {
type InMemoryKey struct {
messageKey Key
seqNum uint
sessionID []byte
}

// Put saves the given mk under the specified key and msgNum.
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
if s.keys == nil {
s.keys = make(map[Key]map[uint]Key)
s.keys = make(map[Key]map[uint]InMemoryKey)
}
if _, ok := s.keys[pubKey]; !ok {
s.keys[pubKey] = make(map[uint]Key)
s.keys[pubKey] = make(map[uint]InMemoryKey)
}
s.keys[pubKey][msgNum] = InMemoryKey{
sessionID: sessionID,
messageKey: mk,
seqNum: seqNum,
}
s.keys[pubKey][msgNum] = mk
return nil
}

// See KeysStorage.
// DeleteMk ensures there's no message key under the specified key and msgNum.
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
if s.keys == nil {
return nil

@@ -72,19 +90,58 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
return nil
}

// See KeysStorage.
func (s *KeysStorageInMemory) DeletePk(pubKey Key) error {
if s.keys == nil {
return nil
// TruncateMks truncates the number of keys to maxKeys.
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
var seqNos []uint
// Collect all seq numbers
for _, keys := range s.keys {
for _, inMemoryKey := range keys {
if bytes.Equal(inMemoryKey.sessionID, sessionID) {
seqNos = append(seqNos, inMemoryKey.seqNum)
}
if _, ok := s.keys[pubKey]; !ok {
return nil
}
delete(s.keys, pubKey)
}

// Nothing to do if we haven't reached the limit
if len(seqNos) <= maxKeys {
return nil
}

// See KeysStorage.
// Take the sequence numbers we care about
sort.Slice(seqNos, func(i, j int) bool { return seqNos[i] < seqNos[j] })
toDeleteSlice := seqNos[:len(seqNos)-maxKeys]

// Put in map for easier lookup
toDelete := make(map[uint]bool)

for _, seqNo := range toDeleteSlice {
toDelete[seqNo] = true
}

for pubKey, keys := range s.keys {
for i, inMemoryKey := range keys {
if toDelete[inMemoryKey.seqNum] && bytes.Equal(inMemoryKey.sessionID, sessionID) {
delete(s.keys[pubKey], i)
}
}
}

return nil
}

// DeleteOldMKeys deletes old message keys for a session.
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
for pubKey, keys := range s.keys {
for i, inMemoryKey := range keys {
if inMemoryKey.seqNum <= deleteUntilSeqKey && bytes.Equal(inMemoryKey.sessionID, sessionID) {
delete(s.keys[pubKey], i)
}
}
}
return nil
}

// Count returns number of message keys stored under the specified key.
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
if s.keys == nil {
return 0, nil
@@ -92,7 +149,16 @@ func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
return uint(len(s.keys[pubKey])), nil
}

// See KeysStorage.
// All returns all the keys
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
return s.keys, nil
response := make(map[Key]map[uint]Key)

for pubKey, keys := range s.keys {
response[pubKey] = make(map[uint]Key)
for n, key := range keys {
response[pubKey][n] = key.messageKey
}
}

return response, nil
}
@@ -17,7 +17,7 @@ func WithMaxSkip(n int) option {
}
}

// WithMaxKeep specifies the maximum number of ratchet steps before a message is deleted.
// WithMaxKeep specifies how long we keep message keys, counted in number of messages received
// nolint: golint
func WithMaxKeep(n int) option {
return func(s *State) error {

@@ -29,6 +29,18 @@ func WithMaxKeep(n int) option {
}
}

// WithMaxMessageKeysPerSession specifies the maximum number of message keys per session
// nolint: golint
func WithMaxMessageKeysPerSession(n int) option {
return func(s *State) error {
if n < 0 {
return fmt.Errorf("n must be non-negative")
}
s.MaxMessageKeysPerSession = n
return nil
}
}

// WithKeysStorage replaces the default keys storage with the specified.
// nolint: golint
func WithKeysStorage(ks KeysStorage) option {
@@ -130,7 +130,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
)

// Is there a new ratchet key?
isDHStepped := false
if m.Header.DH != sc.DHr {
if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)

@@ -138,7 +137,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
if err = sc.dhRatchet(m.Header); err != nil {
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
}
isDHStepped = true
}

// After all, update the current chain.

@@ -151,17 +149,13 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
return nil, fmt.Errorf("can't decrypt: %s", err)
}

// Apply changes.
if err := s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
return nil, err
}
// Increment the number of keys
sc.KeysCount++

if isDHStepped {
err = s.deleteSkippedKeys(s.DHr)
if err != nil {
// Apply changes.
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
return nil, err
}
}

// Store state
if err := s.store(); err != nil {
@@ -103,12 +103,9 @@ func (s *sessionHE) RatchetDecrypt(m MessageHE, ad []byte) ([]byte, error) {
return nil, fmt.Errorf("can't decrypt: %s", err)
}

if err = s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
if err = s.applyChanges(sc, []byte("FIXME"), append(skippedKeys1, skippedKeys2...)); err != nil {
return nil, fmt.Errorf("failed to apply changes: %s", err)
}
if step {
_ = s.deleteSkippedKeys(s.HKr)
}

return plaintext, nil
}
@@ -42,14 +42,18 @@ type State struct {
// Sending header key and next header key. Only used for header encryption.
HKs, NHKs Key

// Number of ratchet steps after which all skipped message keys for that key will be deleted.
// How long we keep messages keys, counted in number of messages received,
// for example if MaxKeep is 5 we only keep the last 5 messages keys, deleting everything n - 5.
MaxKeep uint

// Max number of message keys per session, older keys will be deleted in FIFO fashion
MaxMessageKeysPerSession int

// The number of the current ratchet step.
Step uint

// Which key for the receiving chain was used at the specified step.
DeleteKeys map[uint]Key
// KeysCount the number of keys generated for decrypting
KeysCount uint
}

func DefaultState(sharedKey Key) State {

@@ -65,8 +69,9 @@ func DefaultState(sharedKey Key) State {
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
MkSkipped: &KeysStorageInMemory{},
MaxSkip: 1000,
MaxKeep: 100,
DeleteKeys: make(map[uint]Key),
MaxMessageKeysPerSession: 2000,
MaxKeep: 2000,
KeysCount: 0,
}
}

@@ -112,6 +117,7 @@ type skippedKey struct {
key Key
nr uint
mk Key
seq uint
}

// skipMessageKeys skips message keys in the current receiving chain.
@@ -119,14 +125,11 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
if until < uint(s.RecvCh.N) {
return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
}
nSkipped, err := s.MkSkipped.Count(key)
if err != nil {
return nil, err
}

if until-uint(s.RecvCh.N)+nSkipped > s.MaxSkip {
if uint(s.RecvCh.N)+s.MaxSkip < until {
return nil, fmt.Errorf("too many messages")
}

skipped := []skippedKey{}
for uint(s.RecvCh.N) < until {
mk := s.RecvCh.step()

@@ -134,32 +137,31 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
key: key,
nr: uint(s.RecvCh.N - 1),
mk: mk,
seq: s.KeysCount,
})
// Increment key count
s.KeysCount++

}
return skipped, nil
}

func (s *State) applyChanges(sc State, skipped []skippedKey) error {
func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error {
*s = sc
for _, skipped := range skipped {
if err := s.MkSkipped.Put(skipped.key, skipped.nr, skipped.mk); err != nil {
if err := s.MkSkipped.Put(sessionID, skipped.key, skipped.nr, skipped.mk, skipped.seq); err != nil {
return err
}
}

if err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession); err != nil {
return err
}
if s.KeysCount >= s.MaxKeep {
if err := s.MkSkipped.DeleteOldMks(sessionID, s.KeysCount-s.MaxKeep); err != nil {
return err
}
}

return nil
}

func (s *State) deleteSkippedKeys(key Key) error {

s.DeleteKeys[s.Step] = key
s.Step++
if hk, ok := s.DeleteKeys[s.Step-s.MaxKeep]; ok {
if err := s.MkSkipped.DeletePk(hk); err != nil {
return err
}

delete(s.DeleteKeys, s.Step-s.MaxKeep)
}
return nil
}