mirror of
https://github.com/status-im/status-go.git
synced 2025-02-16 16:56:53 +00:00
Change handling of skipped/deleted keys & add version (#1261)
- Skipped keys The purpose of limiting the number of skipped keys generated is to avoid a dos attack whereby an attacker would send a large N, forcing the device to compute all the keys between currentN..N . Previously the logic for handling skipped keys was: - If in the current receiving chain there are more than maxSkip keys, throw an error This is problematic as in long-lived session dropped/unreceived messages starts piling up, eventually reaching the threshold (1000 dropped/unreceived messages). This logic has been changed to be more inline with signals spec, and now it is: - If N is > currentN + maxSkip, throw an error The purpose of limiting the number of skipped keys stored is to avoid a dos attack whereby an attacker would force us to store a large number of keys, filling up our storage. Previously the logic for handling old keys was: - Once you have maxKeep ratchet steps, delete any key from currentRatchet - maxKeep. This, in combination with the maxSkip implementation, capped the number of stored keys to maxSkip * maxKeep. The logic has been changed to: - Keep a maximum of MaxMessageKeysPerSession and additionally we delete any key that has a sequence number < currentSeqNum - maxKeep - Version We check now the version of the bundle so that when we get a bundle from the same installationID with a higher version, we mark the previous bundle as expired and use the new bundle the next time a message is sent
This commit is contained in:
parent
58bd36e79e
commit
ee3c05c79b
4
Gopkg.lock
generated
4
Gopkg.lock
generated
@ -782,11 +782,11 @@
|
|||||||
version = "v1.1"
|
version = "v1.1"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:b80bfc6278c27edcb1946fc38bc61f4c3a1cd50978083a4a79b7c875c70ed9d8"
|
digest = "1:41fb72d7a71f37f1f9c766d965178636ecda21b429b1f2e3fff42cfc31279751"
|
||||||
name = "github.com/status-im/doubleratchet"
|
name = "github.com/status-im/doubleratchet"
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
pruneopts = "NUT"
|
pruneopts = "NUT"
|
||||||
revision = "321788dbb6eac36f7dab04e631db139e13bb280b"
|
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
|
@ -160,7 +160,7 @@
|
|||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
name = "github.com/status-im/doubleratchet"
|
name = "github.com/status-im/doubleratchet"
|
||||||
revision = "321788dbb6eac36f7dab04e631db139e13bb280b"
|
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"
|
||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
name = "github.com/status-im/migrate"
|
name = "github.com/status-im/migrate"
|
||||||
|
@ -21,6 +21,15 @@ var ErrSessionNotFound = errors.New("session not found")
|
|||||||
// If we have no bundles, we use a constant so that the message can reach any device
|
// If we have no bundles, we use a constant so that the message can reach any device
|
||||||
const noInstallationID = "none"
|
const noInstallationID = "none"
|
||||||
|
|
||||||
|
// How many consecutive messages can be skipped in the receiving chain
|
||||||
|
const maxSkip = 1000
|
||||||
|
|
||||||
|
// Any message with seqNo <= currentSeq - maxKeep will be deleted
|
||||||
|
const maxKeep = 3000
|
||||||
|
|
||||||
|
// How many keys do we store in total per session
|
||||||
|
const maxMessageKeysPerSession = 2000
|
||||||
|
|
||||||
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol
|
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol
|
||||||
type EncryptionService struct {
|
type EncryptionService struct {
|
||||||
log log.Logger
|
log log.Logger
|
||||||
@ -251,9 +260,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
|
|||||||
keyPair,
|
keyPair,
|
||||||
s.persistence.GetSessionStorage(),
|
s.persistence.GetSessionStorage(),
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
// TODO: Temporarily increase to a high number, until
|
dr.WithMaxSkip(maxSkip),
|
||||||
// we make sure it's a sliding window rather than dropping
|
dr.WithMaxKeep(maxKeep),
|
||||||
dr.WithMaxSkip(10000),
|
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}))
|
dr.WithCrypto(crypto.EthereumCrypto{}))
|
||||||
} else {
|
} else {
|
||||||
session, err = dr.NewWithRemoteKey(
|
session, err = dr.NewWithRemoteKey(
|
||||||
@ -262,9 +271,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k
|
|||||||
keyPair.PubKey,
|
keyPair.PubKey,
|
||||||
s.persistence.GetSessionStorage(),
|
s.persistence.GetSessionStorage(),
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
// TODO: Temporarily increase to a high number, until
|
dr.WithMaxSkip(maxSkip),
|
||||||
// we make sure it's a sliding window rather than dropping
|
dr.WithMaxKeep(maxKeep),
|
||||||
dr.WithMaxSkip(10000),
|
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}))
|
dr.WithCrypto(crypto.EthereumCrypto{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -291,9 +300,9 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
|
|||||||
drInfo.ID,
|
drInfo.ID,
|
||||||
sessionStorage,
|
sessionStorage,
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
// TODO: Temporarily increase to a high number, until
|
dr.WithMaxSkip(maxSkip),
|
||||||
// we make sure it's a sliding window rather than dropping
|
dr.WithMaxKeep(maxKeep),
|
||||||
dr.WithMaxSkip(10000),
|
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}),
|
dr.WithCrypto(crypto.EthereumCrypto{}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -342,9 +351,9 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr
|
|||||||
drInfo.ID,
|
drInfo.ID,
|
||||||
sessionStorage,
|
sessionStorage,
|
||||||
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
dr.WithKeysStorage(s.persistence.GetKeysStorage()),
|
||||||
// TODO: Temporarily increase to a high number, until
|
dr.WithMaxSkip(maxSkip),
|
||||||
// we make sure it's a sliding window rather than dropping
|
dr.WithMaxKeep(maxKeep),
|
||||||
dr.WithMaxSkip(10000),
|
dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession),
|
||||||
dr.WithCrypto(crypto.EthereumCrypto{}),
|
dr.WithCrypto(crypto.EthereumCrypto{}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -314,6 +314,213 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
|
|||||||
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
|
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Previous implementation allowed max maxSkip keys in the same receiving chain
|
||||||
|
// leading to a problem whereby dropped messages would accumulate and eventually
|
||||||
|
// we would not be able to decrypt any new message anymore.
|
||||||
|
// Here we are testing that maxSkip only applies to *consecutive* messages, not
|
||||||
|
// overall.
|
||||||
|
func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
|
||||||
|
bobText := []byte("text")
|
||||||
|
|
||||||
|
bobKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
aliceKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
bobBundle, err := s.bob.CreateBundle(bobKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add bob bundle
|
||||||
|
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
aliceBundle, err := s.alice.CreateBundle(aliceKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add alice bundle
|
||||||
|
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Bob sends a message
|
||||||
|
|
||||||
|
for i := 0; i < maxSkip; i++ {
|
||||||
|
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bob sends a message
|
||||||
|
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Alice receives the message
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Bob sends a message
|
||||||
|
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Bob sends a message
|
||||||
|
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Alice receives the message, we should have maxSkip + 1 keys in the db, but
|
||||||
|
// we should not throw an error
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that an error is thrown if max skip is reached
|
||||||
|
func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
|
||||||
|
bobText := []byte("text")
|
||||||
|
|
||||||
|
bobKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
aliceKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
bobBundle, err := s.bob.CreateBundle(bobKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add bob bundle
|
||||||
|
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
aliceBundle, err := s.alice.CreateBundle(aliceKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add alice bundle
|
||||||
|
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Bob sends a message
|
||||||
|
|
||||||
|
for i := 0; i < maxSkip+1; i++ {
|
||||||
|
_, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bob sends a message
|
||||||
|
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Alice receives the message
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
|
||||||
|
s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
|
||||||
|
bobText := []byte("text")
|
||||||
|
|
||||||
|
bobKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
aliceKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
bobBundle, err := s.bob.CreateBundle(bobKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add bob bundle
|
||||||
|
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
aliceBundle, err := s.alice.CreateBundle(aliceKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add alice bundle
|
||||||
|
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We create just enough messages so that the first key should be deleted
|
||||||
|
|
||||||
|
nMessages := maxMessageKeysPerSession + maxMessageKeysPerSession/maxSkip + 2
|
||||||
|
messages := make([]map[string]*DirectMessageProtocol, nMessages)
|
||||||
|
for i := 0; i < nMessages; i++ {
|
||||||
|
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
messages[i] = m
|
||||||
|
|
||||||
|
// We decrypt some messages otherwise we hit maxSkip limit
|
||||||
|
if i%maxSkip == 0 {
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Another message to trigger the deletion
|
||||||
|
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We decrypt the first message, and it should fail
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
|
||||||
|
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||||
|
|
||||||
|
// We decrypt the second message, and it should be decrypted
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[2])
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EncryptionServiceTestSuite) TestMaxKeep() {
|
||||||
|
bobText := []byte("text")
|
||||||
|
|
||||||
|
bobKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
aliceKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
bobBundle, err := s.bob.CreateBundle(bobKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add bob bundle
|
||||||
|
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create a bundle
|
||||||
|
aliceBundle, err := s.alice.CreateBundle(aliceKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We add alice bundle
|
||||||
|
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// We decrypt all messages but 1 & 2
|
||||||
|
messages := make([]map[string]*DirectMessageProtocol, maxKeep)
|
||||||
|
for i := 0; i < maxKeep; i++ {
|
||||||
|
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||||
|
messages[i] = m
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
if i != 0 && i != 1 {
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// We decrypt the first message, and it should fail, as it should have been removed
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0])
|
||||||
|
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||||
|
|
||||||
|
// We decrypt the second message, and it should be decrypted
|
||||||
|
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Alice has Bob's bundle
|
// Alice has Bob's bundle
|
||||||
// Bob has Alice's bundle
|
// Bob has Alice's bundle
|
||||||
// Bob sends a message to alice
|
// Bob sends a message to alice
|
||||||
@ -557,6 +764,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
|
|||||||
|
|
||||||
bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID)
|
bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
// We set the version
|
||||||
|
|
||||||
|
bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1
|
||||||
|
|
||||||
err = SignBundle(bobKey, bobBundle2)
|
err = SignBundle(bobKey, bobBundle2)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
// 1536754952_initial_schema.up.sql
|
// 1536754952_initial_schema.up.sql
|
||||||
// 1539249977_update_ratchet_info.down.sql
|
// 1539249977_update_ratchet_info.down.sql
|
||||||
// 1539249977_update_ratchet_info.up.sql
|
// 1539249977_update_ratchet_info.up.sql
|
||||||
|
// 1540715431_add_version.down.sql
|
||||||
|
// 1540715431_add_version.up.sql
|
||||||
// static.go
|
// static.go
|
||||||
// DO NOT EDIT!
|
// DO NOT EDIT!
|
||||||
|
|
||||||
@ -87,7 +89,7 @@ func _1536754952_initial_schemaDownSql() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)}
|
info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -107,7 +109,7 @@ func _1536754952_initial_schemaUpSql() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539252806, 0)}
|
info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -127,7 +129,7 @@ func _1539249977_update_ratchet_infoDownSql() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1539250187, 0)}
|
info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -147,7 +149,47 @@ func _1539249977_update_ratchet_infoUpSql() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1539250201, 0)}
|
info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)}
|
||||||
|
a := &asset{bytes: bytes, info: info}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var __1540715431_add_versionDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\xf4\x09\x71\x0d\x52\x08\x71\x74\xf2\x71\x55\xc8\x4e\xad\x2c\x56\x70\x09\xf2\x0f\x50\x70\xf6\xf7\x09\xf5\xf5\x53\x28\x4e\x2d\x2e\xce\xcc\xcf\x8b\xcf\x4c\xb1\xe6\x42\x56\x08\x15\x47\x55\x0c\xd2\x1d\x9f\x9c\x5f\x9a\x57\x82\xaa\x38\xa9\x34\x2f\x25\x27\x15\x55\x6d\x59\x6a\x11\xc8\x00\x6b\x2e\x40\x00\x00\x00\xff\xff\xda\x5d\x80\x2d\x7f\x00\x00\x00")
|
||||||
|
|
||||||
|
func _1540715431_add_versionDownSqlBytes() ([]byte, error) {
|
||||||
|
return bindataRead(
|
||||||
|
__1540715431_add_versionDownSql,
|
||||||
|
"1540715431_add_version.down.sql",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _1540715431_add_versionDownSql() (*asset, error) {
|
||||||
|
bytes, err := _1540715431_add_versionDownSqlBytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := bindataFileInfo{name: "1540715431_add_version.down.sql", size: 127, mode: os.FileMode(420), modTime: time.Unix(1540989119, 0)}
|
||||||
|
a := &asset{bytes: bytes, info: info}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var __1540715431_add_versionUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x8c\xcd\xb1\x0e\x02\x21\x0c\xc6\xf1\xdd\xa7\xf8\x1e\xc1\xdd\x09\xa4\x67\x4c\x7a\x90\x90\x32\x93\xe8\x31\x5c\x54\x2e\x8a\x98\xf8\xf6\x06\xe3\xc2\xa2\xae\x6d\xff\xbf\x1a\x62\x12\xc2\xe0\xdd\x88\x53\x7a\x96\xcd\x4a\xb1\x90\x87\x28\xcd\xf4\x9e\x40\x19\x83\xad\xe3\x30\x5a\x94\x74\x8d\xb9\x5e\xb0\xb7\x42\x3b\xf2\xb0\x4e\x60\x03\x33\x0c\x0d\x2a\xb0\x60\xfd\xab\x2f\x65\x5e\x72\x9c\x27\x68\x76\xba\x3f\xfe\x2c\xbb\xa0\x01\xf1\xb8\xd4\x7c\xff\xfb\xe7\xa1\xe6\xe9\x9c\x3a\xe5\x91\x6e\x4d\xfe\x4a\xbc\x02\x00\x00\xff\xff\x0e\x27\x2c\x52\x09\x01\x00\x00")
|
||||||
|
|
||||||
|
func _1540715431_add_versionUpSqlBytes() ([]byte, error) {
|
||||||
|
return bindataRead(
|
||||||
|
__1540715431_add_versionUpSql,
|
||||||
|
"1540715431_add_version.up.sql",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _1540715431_add_versionUpSql() (*asset, error) {
|
||||||
|
bytes, err := _1540715431_add_versionUpSqlBytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := bindataFileInfo{name: "1540715431_add_version.up.sql", size: 265, mode: os.FileMode(420), modTime: time.Unix(1540989075, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -167,7 +209,7 @@ func staticGo() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)}
|
info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -228,6 +270,8 @@ var _bindata = map[string]func() (*asset, error){
|
|||||||
"1536754952_initial_schema.up.sql": _1536754952_initial_schemaUpSql,
|
"1536754952_initial_schema.up.sql": _1536754952_initial_schemaUpSql,
|
||||||
"1539249977_update_ratchet_info.down.sql": _1539249977_update_ratchet_infoDownSql,
|
"1539249977_update_ratchet_info.down.sql": _1539249977_update_ratchet_infoDownSql,
|
||||||
"1539249977_update_ratchet_info.up.sql": _1539249977_update_ratchet_infoUpSql,
|
"1539249977_update_ratchet_info.up.sql": _1539249977_update_ratchet_infoUpSql,
|
||||||
|
"1540715431_add_version.down.sql": _1540715431_add_versionDownSql,
|
||||||
|
"1540715431_add_version.up.sql": _1540715431_add_versionUpSql,
|
||||||
"static.go": staticGo,
|
"static.go": staticGo,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -275,6 +319,8 @@ var _bintree = &bintree{nil, map[string]*bintree{
|
|||||||
"1536754952_initial_schema.up.sql": &bintree{_1536754952_initial_schemaUpSql, map[string]*bintree{}},
|
"1536754952_initial_schema.up.sql": &bintree{_1536754952_initial_schemaUpSql, map[string]*bintree{}},
|
||||||
"1539249977_update_ratchet_info.down.sql": &bintree{_1539249977_update_ratchet_infoDownSql, map[string]*bintree{}},
|
"1539249977_update_ratchet_info.down.sql": &bintree{_1539249977_update_ratchet_infoDownSql, map[string]*bintree{}},
|
||||||
"1539249977_update_ratchet_info.up.sql": &bintree{_1539249977_update_ratchet_infoUpSql, map[string]*bintree{}},
|
"1539249977_update_ratchet_info.up.sql": &bintree{_1539249977_update_ratchet_infoUpSql, map[string]*bintree{}},
|
||||||
|
"1540715431_add_version.down.sql": &bintree{_1540715431_add_versionDownSql, map[string]*bintree{}},
|
||||||
|
"1540715431_add_version.up.sql": &bintree{_1540715431_add_versionUpSql, map[string]*bintree{}},
|
||||||
"static.go": &bintree{staticGo, map[string]*bintree{}},
|
"static.go": &bintree{staticGo, map[string]*bintree{}},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
@ -61,8 +61,7 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa
|
|||||||
// BuildDirectMessage marshals a 1:1 chat message given the user identity private key, the recipient's public key, and a payload
|
// BuildDirectMessage marshals a 1:1 chat message given the user identity private key, the recipient's public key, and a payload
|
||||||
func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte, theirPublicKeys ...*ecdsa.PublicKey) (map[*ecdsa.PublicKey][]byte, error) {
|
func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte, theirPublicKeys ...*ecdsa.PublicKey) (map[*ecdsa.PublicKey][]byte, error) {
|
||||||
response := make(map[*ecdsa.PublicKey][]byte)
|
response := make(map[*ecdsa.PublicKey][]byte)
|
||||||
publicKeys := append(theirPublicKeys, &myIdentityKey.PublicKey)
|
for _, publicKey := range theirPublicKeys {
|
||||||
for _, publicKey := range publicKeys {
|
|
||||||
// Encrypt payload
|
// Encrypt payload
|
||||||
encryptionResponse, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload)
|
encryptionResponse, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,7 +56,7 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() {
|
|||||||
})
|
})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
|
||||||
marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey)
|
marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey, &aliceKey.PublicKey)
|
||||||
|
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.NotNil(marshaledMsg, "It creates a message")
|
s.NotNil(marshaledMsg, "It creates a message")
|
||||||
|
@ -16,6 +16,9 @@ import (
|
|||||||
"github.com/status-im/status-go/services/shhext/chat/migrations"
|
"github.com/status-im/status-go/services/shhext/chat/migrations"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A safe max number of rows
|
||||||
|
const maxNumberOfRows = 100000000
|
||||||
|
|
||||||
// SQLLitePersistence represents a persistence service tied to an SQLite database
|
// SQLLitePersistence represents a persistence service tied to an SQLite database
|
||||||
type SQLLitePersistence struct {
|
type SQLLitePersistence struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
@ -107,7 +110,20 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for installationID, signedPreKey := range b.GetBundle().GetSignedPreKeys() {
|
for installationID, signedPreKey := range b.GetBundle().GetSignedPreKeys() {
|
||||||
stmt, err := tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, timestamp) VALUES(?, ?, ?, ?, ?)")
|
var version uint32
|
||||||
|
stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
err = stmt.QueryRow(installationID, b.GetBundle().GetIdentity()).Scan(&version)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -118,6 +134,7 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error {
|
|||||||
b.GetPrivateSignedPreKey(),
|
b.GetPrivateSignedPreKey(),
|
||||||
signedPreKey.GetSignedPreKey(),
|
signedPreKey.GetSignedPreKey(),
|
||||||
installationID,
|
installationID,
|
||||||
|
version+1,
|
||||||
time.Now().UnixNano(),
|
time.Now().UnixNano(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -144,15 +161,18 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
|
|||||||
|
|
||||||
for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
|
for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
|
||||||
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
|
signedPreKey := signedPreKeyContainer.GetSignedPreKey()
|
||||||
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, timestamp) VALUES( ?, ?, ?, ?)")
|
version := signedPreKeyContainer.GetVersion()
|
||||||
|
insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer insertStmt.Close()
|
defer insertStmt.Close()
|
||||||
|
|
||||||
_, err = insertStmt.Exec(
|
_, err = insertStmt.Exec(
|
||||||
b.GetIdentity(),
|
b.GetIdentity(),
|
||||||
signedPreKey,
|
signedPreKey,
|
||||||
installationID,
|
installationID,
|
||||||
|
version,
|
||||||
time.Now().UnixNano(),
|
time.Now().UnixNano(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -160,7 +180,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Mark old bundles as expired
|
// Mark old bundles as expired
|
||||||
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND signed_pre_key != ?")
|
updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -169,7 +189,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error {
|
|||||||
_, err = updateStmt.Exec(
|
_, err = updateStmt.Exec(
|
||||||
b.GetIdentity(),
|
b.GetIdentity(),
|
||||||
installationID,
|
installationID,
|
||||||
signedPreKey,
|
version,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
@ -281,7 +301,7 @@ func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error {
|
|||||||
func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundle, error) {
|
func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundle, error) {
|
||||||
|
|
||||||
identity := crypto.CompressPubkey(publicKey)
|
identity := crypto.CompressPubkey(publicKey)
|
||||||
stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id FROM bundles WHERE expired = 0 AND identity = ? ORDER BY timestamp DESC")
|
stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? ORDER BY version DESC")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -304,16 +324,21 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundl
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var signedPreKey []byte
|
var signedPreKey []byte
|
||||||
var installationID string
|
var installationID string
|
||||||
|
var version uint32
|
||||||
rowCount++
|
rowCount++
|
||||||
err = rows.Scan(
|
err = rows.Scan(
|
||||||
&signedPreKey,
|
&signedPreKey,
|
||||||
&installationID,
|
&installationID,
|
||||||
|
&version,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey}
|
bundle.SignedPreKeys[installationID] = &SignedPreKey{
|
||||||
|
SignedPreKey: signedPreKey,
|
||||||
|
Version: version,
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -448,17 +473,53 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Put stores a key with the specified public key, message number and message key
|
// Put stores a key with the specified public key, message number and message key
|
||||||
func (s *SQLLiteKeysStorage) Put(pubKey dr.Key, msgNum uint, mk dr.Key) error {
|
func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
|
||||||
stmt, err := s.db.Prepare("insert into keys(public_key, msg_num, message_key) values(?, ?, ?)")
|
stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
|
|
||||||
_, err = stmt.Exec(
|
_, err = stmt.Exec(
|
||||||
|
sessionID,
|
||||||
pubKey[:],
|
pubKey[:],
|
||||||
msgNum,
|
msgNum,
|
||||||
mk[:],
|
mk[:],
|
||||||
|
seqNum,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOldMks caps remove any key < seq_num, included
|
||||||
|
func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
|
||||||
|
stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
_, err = stmt.Exec(
|
||||||
|
sessionID,
|
||||||
|
deleteUntil,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion
|
||||||
|
func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
|
||||||
|
stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
_, err = stmt.Exec(
|
||||||
|
sessionID,
|
||||||
|
// We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT
|
||||||
|
maxNumberOfRows,
|
||||||
|
maxKeysPerSession,
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@ -480,21 +541,6 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePk deletes the keys with the specified public key
|
|
||||||
func (s *SQLLiteKeysStorage) DeletePk(pubKey dr.Key) error {
|
|
||||||
stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ?")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
_, err = stmt.Exec(
|
|
||||||
pubKey[:],
|
|
||||||
)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count returns the count of keys with the specified public key
|
// Count returns the count of keys with the specified public key
|
||||||
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
|
func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
|
||||||
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
|
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?")
|
||||||
@ -512,6 +558,23 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) {
|
|||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountAll returns the count of keys with the specified public key
|
||||||
|
func (s *SQLLiteKeysStorage) CountAll() (uint, error) {
|
||||||
|
stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys")
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
var count uint
|
||||||
|
err = stmt.QueryRow().Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
// All returns nil
|
// All returns nil
|
||||||
func (s *SQLLiteKeysStorage) All() (map[dr.Key]map[uint]dr.Key, error) {
|
func (s *SQLLiteKeysStorage) All() (map[dr.Key]map[uint]dr.Key, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -525,6 +588,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
|
|||||||
dhsPrivate := dhs.PrivateKey()
|
dhsPrivate := dhs.PrivateKey()
|
||||||
pn := state.PN
|
pn := state.PN
|
||||||
step := state.Step
|
step := state.Step
|
||||||
|
keysCount := state.KeysCount
|
||||||
|
|
||||||
rootChainKey := state.RootCh.CK[:]
|
rootChainKey := state.RootCh.CK[:]
|
||||||
|
|
||||||
@ -534,7 +598,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
|
|||||||
recvChainKey := state.RecvCh.CK[:]
|
recvChainKey := state.RecvCh.CK[:]
|
||||||
recvChainN := state.RecvCh.N
|
recvChainN := state.RecvCh.N
|
||||||
|
|
||||||
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
|
stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -552,6 +616,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
|
|||||||
recvChainN,
|
recvChainN,
|
||||||
pn,
|
pn,
|
||||||
step,
|
step,
|
||||||
|
keysCount,
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@ -559,7 +624,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error {
|
|||||||
|
|
||||||
// Load retrieves the double ratchet state for a given ID
|
// Load retrieves the double ratchet state for a given ID
|
||||||
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
||||||
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step FROM sessions WHERE id = ?")
|
stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -577,6 +642,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
|||||||
recvChainN uint
|
recvChainN uint
|
||||||
pn uint
|
pn uint
|
||||||
step uint
|
step uint
|
||||||
|
keysCount uint
|
||||||
)
|
)
|
||||||
|
|
||||||
err = stmt.QueryRow(id).Scan(
|
err = stmt.QueryRow(id).Scan(
|
||||||
@ -590,6 +656,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
|||||||
&recvChainN,
|
&recvChainN,
|
||||||
&pn,
|
&pn,
|
||||||
&step,
|
&step,
|
||||||
|
&keysCount,
|
||||||
)
|
)
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows:
|
case sql.ErrNoRows:
|
||||||
@ -599,6 +666,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) {
|
|||||||
|
|
||||||
state.PN = uint32(pn)
|
state.PN = uint32(pn)
|
||||||
state.Step = step
|
state.Step = step
|
||||||
|
state.KeysCount = keysCount
|
||||||
|
|
||||||
state.DHs = ecrypto.DHPair{
|
state.DHs = ecrypto.DHPair{
|
||||||
PrvKey: toKey(dhsPrivate),
|
PrvKey: toKey(dhsPrivate),
|
||||||
|
@ -10,8 +10,11 @@ import (
|
|||||||
var (
|
var (
|
||||||
pubKey1 = dr.Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c}
|
pubKey1 = dr.Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c}
|
||||||
pubKey2 = dr.Key{0xec, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
pubKey2 = dr.Key{0xec, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||||
mk1 = dr.Key{0xeb, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
mk1 = dr.Key{0x00, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||||
mk2 = dr.Key{0xed, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
mk2 = dr.Key{0x01, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||||
|
mk3 = dr.Key{0x02, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||||
|
mk4 = dr.Key{0x03, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||||
|
mk5 = dr.Key{0x04, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSQLLitePersistenceKeysStorageTestSuite(t *testing.T) {
|
func TestSQLLitePersistenceKeysStorageTestSuite(t *testing.T) {
|
||||||
@ -47,10 +50,84 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLiteGetMissin
|
|||||||
|
|
||||||
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Put() {
|
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Put() {
|
||||||
// Act and assert.
|
// Act and assert.
|
||||||
err := s.service.Put(pubKey1, 0, mk1)
|
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_DeleteOldMks() {
|
||||||
|
// Insert keys out-of-order
|
||||||
|
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey1, 2, mk3, 20)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey1, 3, mk4, 21)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey1, 4, mk5, 22)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
err = s.service.DeleteOldMks([]byte("session-id"), 20)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
_, ok, err := s.service.Get(pubKey1, 0)
|
||||||
|
s.NoError(err)
|
||||||
|
s.False(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey1, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
s.False(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey1, 2)
|
||||||
|
s.NoError(err)
|
||||||
|
s.False(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey1, 3)
|
||||||
|
s.NoError(err)
|
||||||
|
s.True(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey1, 4)
|
||||||
|
s.NoError(err)
|
||||||
|
s.True(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_TruncateMks() {
|
||||||
|
// Insert keys out-of-order
|
||||||
|
err := s.service.Put([]byte("session-id"), pubKey2, 2, mk5, 5)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey2, 0, mk3, 3)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey2, 1, mk4, 4)
|
||||||
|
s.NoError(err)
|
||||||
|
err = s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
err = s.service.TruncateMks([]byte("session-id"), 2)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
_, ok, err := s.service.Get(pubKey1, 0)
|
||||||
|
s.NoError(err)
|
||||||
|
s.False(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey1, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
s.False(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey2, 0)
|
||||||
|
s.NoError(err)
|
||||||
|
s.False(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey2, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
s.True(ok)
|
||||||
|
|
||||||
|
_, ok, err = s.service.Get(pubKey2, 2)
|
||||||
|
s.NoError(err)
|
||||||
|
s.True(ok)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Count() {
|
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Count() {
|
||||||
|
|
||||||
// Act.
|
// Act.
|
||||||
@ -71,12 +148,8 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Delete()
|
|||||||
|
|
||||||
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {
|
func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {
|
||||||
|
|
||||||
// Act and assert.
|
|
||||||
err := s.service.DeletePk(pubKey1)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
// Act.
|
// Act.
|
||||||
err = s.service.Put(pubKey1, 0, mk1)
|
err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
|
||||||
k, ok, err := s.service.Get(pubKey1, 0)
|
k, ok, err := s.service.Get(pubKey1, 0)
|
||||||
@ -124,30 +197,4 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() {
|
|||||||
// Assert.
|
// Assert.
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.EqualValues(0, cnt)
|
s.EqualValues(0, cnt)
|
||||||
|
|
||||||
// Act.
|
|
||||||
err = s.service.Put(pubKey1, 0, mk1)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
err = s.service.Put(pubKey2, 0, mk2)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
err = s.service.DeletePk(pubKey1)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
err = s.service.DeletePk(pubKey1)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
err = s.service.DeletePk(pubKey2)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
cn1, err := s.service.Count(pubKey1)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
cn2, err := s.service.Count(pubKey2)
|
|
||||||
s.NoError(err)
|
|
||||||
|
|
||||||
// Assert.
|
|
||||||
s.Empty(cn1)
|
|
||||||
s.Empty(cn2)
|
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() {
|
|||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
actualKey, err := s.service.GetPrivateKeyBundle([]byte("non-existing"))
|
actualKey, err := s.service.GetPrivateKeyBundle([]byte("non-existing"))
|
||||||
s.Require().NoError(err, "It does not return an error if the bundle is not there")
|
s.Require().NoError(err, "Error was not returned even though bundle is not there")
|
||||||
s.Nil(actualKey)
|
s.Nil(actualKey)
|
||||||
|
|
||||||
anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"))
|
anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id"))
|
||||||
@ -82,7 +82,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
|
|||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
||||||
s.Require().NoError(err, "It does not return an error if the bundle is not there")
|
s.Require().NoError(err, "Error was not returned even though bundle is not there")
|
||||||
s.Nil(actualBundle)
|
s.Nil(actualBundle)
|
||||||
|
|
||||||
bundleContainer, err := NewBundleContainer(key, "1")
|
bundleContainer, err := NewBundleContainer(key, "1")
|
||||||
@ -98,12 +98,82 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() {
|
|||||||
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
|
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() {
|
||||||
|
key, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
||||||
|
s.Require().NoError(err, "Error was not returned even though bundle is not there")
|
||||||
|
s.Nil(actualBundle)
|
||||||
|
|
||||||
|
// Create & add initial bundle
|
||||||
|
bundleContainer, err := NewBundleContainer(key, "1")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
bundle := bundleContainer.GetBundle()
|
||||||
|
err = s.service.AddPublicBundle(bundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create & add a new bundle
|
||||||
|
bundleContainer, err = NewBundleContainer(key, "1")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
bundle = bundleContainer.GetBundle()
|
||||||
|
// We set the version
|
||||||
|
bundle.GetSignedPreKeys()["1"].Version = 1
|
||||||
|
|
||||||
|
err = s.service.AddPublicBundle(bundle)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
|
||||||
|
s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() {
|
||||||
|
key, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
||||||
|
s.Require().NoError(err, "Error was not returned even though bundle is not there")
|
||||||
|
s.Nil(actualBundle)
|
||||||
|
|
||||||
|
// Create & add initial bundle
|
||||||
|
bundleContainer, err := NewBundleContainer(key, "1")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
bundle1 := bundleContainer.GetBundle()
|
||||||
|
err = s.service.AddPublicBundle(bundle1)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create & add a new bundle
|
||||||
|
bundleContainer, err = NewBundleContainer(key, "1")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
bundle2 := bundleContainer.GetBundle()
|
||||||
|
// We set the version
|
||||||
|
bundle2.GetSignedPreKeys()["1"].Version = 1
|
||||||
|
|
||||||
|
err = s.service.AddPublicBundle(bundle2)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Add again the initial bundle
|
||||||
|
err = s.service.AddPublicBundle(bundle1)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
actualBundle, err = s.service.GetPublicBundle(&key.PublicKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity")
|
||||||
|
s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1))
|
||||||
|
s.Equal(bundle2.GetSignedPreKeys()["1"].GetSignedPreKey(), actualBundle.GetSignedPreKeys()["1"].GetSignedPreKey(), "It sets the right prekeys")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
|
func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
|
||||||
key, err := crypto.GenerateKey()
|
key, err := crypto.GenerateKey()
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
||||||
s.Require().NoError(err, "It does not return an error if the bundle is not there")
|
s.Require().NoError(err, "Error was not returned even though bundle is not there")
|
||||||
s.Nil(actualBundle)
|
s.Nil(actualBundle)
|
||||||
|
|
||||||
bundleContainer, err := NewBundleContainer(key, "1")
|
bundleContainer, err := NewBundleContainer(key, "1")
|
||||||
@ -120,8 +190,10 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() {
|
|||||||
// Adding a different bundle
|
// Adding a different bundle
|
||||||
bundleContainer, err = NewBundleContainer(key, "1")
|
bundleContainer, err = NewBundleContainer(key, "1")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
// We set the version
|
||||||
bundle = bundleContainer.GetBundle()
|
bundle = bundleContainer.GetBundle()
|
||||||
|
bundle.GetSignedPreKeys()["1"].Version = 1
|
||||||
|
|
||||||
err = s.service.AddPublicBundle(bundle)
|
err = s.service.AddPublicBundle(bundle)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
@ -139,7 +211,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() {
|
|||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
actualBundle, err := s.service.GetPublicBundle(&key.PublicKey)
|
||||||
s.Require().NoError(err, "It does not return an error if the bundle is not there")
|
s.Require().NoError(err, "Error was not returned even though bundle is not there")
|
||||||
s.Nil(actualBundle)
|
s.Nil(actualBundle)
|
||||||
|
|
||||||
bundleContainer, err := NewBundleContainer(key, "1")
|
bundleContainer, err := NewBundleContainer(key, "1")
|
||||||
@ -273,4 +345,3 @@ func (s *SQLLitePersistenceTestSuite) TestRatchetInfoNoBundle() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add test for MarkBundleExpired
|
// TODO: Add test for MarkBundleExpired
|
||||||
// TODO: Add test for AddPublicBundle checking that it expires previous bundles
|
|
||||||
|
@ -88,7 +88,7 @@ func ConfigCliFleetEthBetaJson() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
|
info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -108,7 +108,7 @@ func ConfigCliFleetEthStagingJson() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
|
info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -128,7 +128,7 @@ func ConfigCliFleetEthTestJson() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)}
|
info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -148,7 +148,7 @@ func ConfigCliLesEnabledJson() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)}
|
info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -168,7 +168,7 @@ func ConfigCliMailserverEnabledJson() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1538032850, 0)}
|
info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
@ -188,7 +188,7 @@ func ConfigStatusChainGenesisJson() (*asset, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)}
|
info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)}
|
||||||
a := &asset{bytes: bytes, info: info}
|
a := &asset{bytes: bytes, info: info}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
3
static/migrations/1540715431_add_version.down.sql
Normal file
3
static/migrations/1540715431_add_version.down.sql
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
ALTER TABLE keys DROP COLUMN session_id;
|
||||||
|
ALTER TABLE sessions DROP COLUMN keys_count;
|
||||||
|
ALTER TABLE bundles DROP COLUMN version;
|
5
static/migrations/1540715431_add_version.up.sql
Normal file
5
static/migrations/1540715431_add_version.up.sql
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
DELETE FROM keys;
|
||||||
|
ALTER TABLE keys ADD COLUMN seq_num INTEGER NOT NULL DEFAULT 0;
|
||||||
|
ALTER TABLE keys ADD COLUMN session_id BLOB;
|
||||||
|
ALTER TABLE sessions ADD COLUMN keys_count INTEGER NOT NULL DEFAULT 0;
|
||||||
|
ALTER TABLE bundles ADD COLUMN version INTEGER NOT NULL DEFAULT 0;
|
108
vendor/github.com/status-im/doubleratchet/keys_storage.go
generated
vendored
108
vendor/github.com/status-im/doubleratchet/keys_storage.go
generated
vendored
@ -1,18 +1,26 @@
|
|||||||
package doubleratchet
|
package doubleratchet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
|
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
|
||||||
type KeysStorage interface {
|
type KeysStorage interface {
|
||||||
// Get returns a message key by the given key and message number.
|
// Get returns a message key by the given key and message number.
|
||||||
Get(k Key, msgNum uint) (mk Key, ok bool, err error)
|
Get(k Key, msgNum uint) (mk Key, ok bool, err error)
|
||||||
|
|
||||||
// Put saves the given mk under the specified key and msgNum.
|
// Put saves the given mk under the specified key and msgNum.
|
||||||
Put(k Key, msgNum uint, mk Key) error
|
Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error
|
||||||
|
|
||||||
// DeleteMk ensures there's no message key under the specified key and msgNum.
|
// DeleteMk ensures there's no message key under the specified key and msgNum.
|
||||||
DeleteMk(k Key, msgNum uint) error
|
DeleteMk(k Key, msgNum uint) error
|
||||||
|
|
||||||
// DeletePk ensures there's no message keys under the specified key.
|
// DeleteOldMKeys deletes old message keys for a session.
|
||||||
DeletePk(k Key) error
|
DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error
|
||||||
|
|
||||||
|
// TruncateMks truncates the number of keys to maxKeys.
|
||||||
|
TruncateMks(sessionID []byte, maxKeys int) error
|
||||||
|
|
||||||
// Count returns number of message keys stored under the specified key.
|
// Count returns number of message keys stored under the specified key.
|
||||||
Count(k Key) (uint, error)
|
Count(k Key) (uint, error)
|
||||||
@ -23,10 +31,10 @@ type KeysStorage interface {
|
|||||||
|
|
||||||
// KeysStorageInMemory is an in-memory message keys storage.
|
// KeysStorageInMemory is an in-memory message keys storage.
|
||||||
type KeysStorageInMemory struct {
|
type KeysStorageInMemory struct {
|
||||||
keys map[Key]map[uint]Key
|
keys map[Key]map[uint]InMemoryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// See KeysStorage.
|
// Get returns a message key by the given key and message number.
|
||||||
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
|
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
|
||||||
if s.keys == nil {
|
if s.keys == nil {
|
||||||
return Key{}, false, nil
|
return Key{}, false, nil
|
||||||
@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return Key{}, false, nil
|
return Key{}, false, nil
|
||||||
}
|
}
|
||||||
return mk, true, nil
|
return mk.messageKey, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// See KeysStorage.
|
type InMemoryKey struct {
|
||||||
func (s *KeysStorageInMemory) Put(pubKey Key, msgNum uint, mk Key) error {
|
messageKey Key
|
||||||
|
seqNum uint
|
||||||
|
sessionID []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put saves the given mk under the specified key and msgNum.
|
||||||
|
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
|
||||||
if s.keys == nil {
|
if s.keys == nil {
|
||||||
s.keys = make(map[Key]map[uint]Key)
|
s.keys = make(map[Key]map[uint]InMemoryKey)
|
||||||
}
|
}
|
||||||
if _, ok := s.keys[pubKey]; !ok {
|
if _, ok := s.keys[pubKey]; !ok {
|
||||||
s.keys[pubKey] = make(map[uint]Key)
|
s.keys[pubKey] = make(map[uint]InMemoryKey)
|
||||||
|
}
|
||||||
|
s.keys[pubKey][msgNum] = InMemoryKey{
|
||||||
|
sessionID: sessionID,
|
||||||
|
messageKey: mk,
|
||||||
|
seqNum: seqNum,
|
||||||
}
|
}
|
||||||
s.keys[pubKey][msgNum] = mk
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// See KeysStorage.
|
// DeleteMk ensures there's no message key under the specified key and msgNum.
|
||||||
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
|
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
|
||||||
if s.keys == nil {
|
if s.keys == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -72,19 +90,58 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// See KeysStorage.
|
// TruncateMks truncates the number of keys to maxKeys.
|
||||||
func (s *KeysStorageInMemory) DeletePk(pubKey Key) error {
|
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
|
||||||
if s.keys == nil {
|
var seqNos []uint
|
||||||
|
// Collect all seq numbers
|
||||||
|
for _, keys := range s.keys {
|
||||||
|
for _, inMemoryKey := range keys {
|
||||||
|
if bytes.Equal(inMemoryKey.sessionID, sessionID) {
|
||||||
|
seqNos = append(seqNos, inMemoryKey.seqNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nothing to do if we haven't reached the limit
|
||||||
|
if len(seqNos) <= maxKeys {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if _, ok := s.keys[pubKey]; !ok {
|
|
||||||
return nil
|
// Take the sequence numbers we care about
|
||||||
|
sort.Slice(seqNos, func(i, j int) bool { return seqNos[i] < seqNos[j] })
|
||||||
|
toDeleteSlice := seqNos[:len(seqNos)-maxKeys]
|
||||||
|
|
||||||
|
// Put in map for easier lookup
|
||||||
|
toDelete := make(map[uint]bool)
|
||||||
|
|
||||||
|
for _, seqNo := range toDeleteSlice {
|
||||||
|
toDelete[seqNo] = true
|
||||||
}
|
}
|
||||||
delete(s.keys, pubKey)
|
|
||||||
|
for pubKey, keys := range s.keys {
|
||||||
|
for i, inMemoryKey := range keys {
|
||||||
|
if toDelete[inMemoryKey.seqNum] && bytes.Equal(inMemoryKey.sessionID, sessionID) {
|
||||||
|
delete(s.keys[pubKey], i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// See KeysStorage.
|
// DeleteOldMKeys deletes old message keys for a session.
|
||||||
|
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
|
||||||
|
for pubKey, keys := range s.keys {
|
||||||
|
for i, inMemoryKey := range keys {
|
||||||
|
if inMemoryKey.seqNum <= deleteUntilSeqKey && bytes.Equal(inMemoryKey.sessionID, sessionID) {
|
||||||
|
delete(s.keys[pubKey], i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns number of message keys stored under the specified key.
|
||||||
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
|
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
|
||||||
if s.keys == nil {
|
if s.keys == nil {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@ -92,7 +149,16 @@ func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
|
|||||||
return uint(len(s.keys[pubKey])), nil
|
return uint(len(s.keys[pubKey])), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// See KeysStorage.
|
// All returns all the keys
|
||||||
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
|
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
|
||||||
return s.keys, nil
|
response := make(map[Key]map[uint]Key)
|
||||||
|
|
||||||
|
for pubKey, keys := range s.keys {
|
||||||
|
response[pubKey] = make(map[uint]Key)
|
||||||
|
for n, key := range keys {
|
||||||
|
response[pubKey][n] = key.messageKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
14
vendor/github.com/status-im/doubleratchet/options.go
generated
vendored
14
vendor/github.com/status-im/doubleratchet/options.go
generated
vendored
@ -17,7 +17,7 @@ func WithMaxSkip(n int) option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMaxKeep specifies the maximum number of ratchet steps before a message is deleted.
|
// WithMaxKeep specifies how long we keep message keys, counted in number of messages received
|
||||||
// nolint: golint
|
// nolint: golint
|
||||||
func WithMaxKeep(n int) option {
|
func WithMaxKeep(n int) option {
|
||||||
return func(s *State) error {
|
return func(s *State) error {
|
||||||
@ -29,6 +29,18 @@ func WithMaxKeep(n int) option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithMaxMessageKeysPerSession specifies the maximum number of message keys per session
|
||||||
|
// nolint: golint
|
||||||
|
func WithMaxMessageKeysPerSession(n int) option {
|
||||||
|
return func(s *State) error {
|
||||||
|
if n < 0 {
|
||||||
|
return fmt.Errorf("n must be non-negative")
|
||||||
|
}
|
||||||
|
s.MaxMessageKeysPerSession = n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithKeysStorage replaces the default keys storage with the specified.
|
// WithKeysStorage replaces the default keys storage with the specified.
|
||||||
// nolint: golint
|
// nolint: golint
|
||||||
func WithKeysStorage(ks KeysStorage) option {
|
func WithKeysStorage(ks KeysStorage) option {
|
||||||
|
14
vendor/github.com/status-im/doubleratchet/session.go
generated
vendored
14
vendor/github.com/status-im/doubleratchet/session.go
generated
vendored
@ -130,7 +130,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Is there a new ratchet key?
|
// Is there a new ratchet key?
|
||||||
isDHStepped := false
|
|
||||||
if m.Header.DH != sc.DHr {
|
if m.Header.DH != sc.DHr {
|
||||||
if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
|
if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
|
||||||
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
|
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
|
||||||
@ -138,7 +137,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||||||
if err = sc.dhRatchet(m.Header); err != nil {
|
if err = sc.dhRatchet(m.Header); err != nil {
|
||||||
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
|
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
|
||||||
}
|
}
|
||||||
isDHStepped = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// After all, update the current chain.
|
// After all, update the current chain.
|
||||||
@ -151,17 +149,13 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply changes.
|
// Increment the number of keys
|
||||||
if err := s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
|
sc.KeysCount++
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if isDHStepped {
|
// Apply changes.
|
||||||
err = s.deleteSkippedKeys(s.DHr)
|
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Store state
|
// Store state
|
||||||
if err := s.store(); err != nil {
|
if err := s.store(); err != nil {
|
||||||
|
5
vendor/github.com/status-im/doubleratchet/session_he.go
generated
vendored
5
vendor/github.com/status-im/doubleratchet/session_he.go
generated
vendored
@ -103,12 +103,9 @@ func (s *sessionHE) RatchetDecrypt(m MessageHE, ad []byte) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
|
if err = s.applyChanges(sc, []byte("FIXME"), append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||||
return nil, fmt.Errorf("failed to apply changes: %s", err)
|
return nil, fmt.Errorf("failed to apply changes: %s", err)
|
||||||
}
|
}
|
||||||
if step {
|
|
||||||
_ = s.deleteSkippedKeys(s.HKr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return plaintext, nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
54
vendor/github.com/status-im/doubleratchet/state.go
generated
vendored
54
vendor/github.com/status-im/doubleratchet/state.go
generated
vendored
@ -42,14 +42,18 @@ type State struct {
|
|||||||
// Sending header key and next header key. Only used for header encryption.
|
// Sending header key and next header key. Only used for header encryption.
|
||||||
HKs, NHKs Key
|
HKs, NHKs Key
|
||||||
|
|
||||||
// Number of ratchet steps after which all skipped message keys for that key will be deleted.
|
// How long we keep messages keys, counted in number of messages received,
|
||||||
|
// for example if MaxKeep is 5 we only keep the last 5 messages keys, deleting everything n - 5.
|
||||||
MaxKeep uint
|
MaxKeep uint
|
||||||
|
|
||||||
|
// Max number of message keys per session, older keys will be deleted in FIFO fashion
|
||||||
|
MaxMessageKeysPerSession int
|
||||||
|
|
||||||
// The number of the current ratchet step.
|
// The number of the current ratchet step.
|
||||||
Step uint
|
Step uint
|
||||||
|
|
||||||
// Which key for the receiving chain was used at the specified step.
|
// KeysCount the number of keys generated for decrypting
|
||||||
DeleteKeys map[uint]Key
|
KeysCount uint
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultState(sharedKey Key) State {
|
func DefaultState(sharedKey Key) State {
|
||||||
@ -65,8 +69,9 @@ func DefaultState(sharedKey Key) State {
|
|||||||
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
|
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||||
MkSkipped: &KeysStorageInMemory{},
|
MkSkipped: &KeysStorageInMemory{},
|
||||||
MaxSkip: 1000,
|
MaxSkip: 1000,
|
||||||
MaxKeep: 100,
|
MaxMessageKeysPerSession: 2000,
|
||||||
DeleteKeys: make(map[uint]Key),
|
MaxKeep: 2000,
|
||||||
|
KeysCount: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,6 +117,7 @@ type skippedKey struct {
|
|||||||
key Key
|
key Key
|
||||||
nr uint
|
nr uint
|
||||||
mk Key
|
mk Key
|
||||||
|
seq uint
|
||||||
}
|
}
|
||||||
|
|
||||||
// skipMessageKeys skips message keys in the current receiving chain.
|
// skipMessageKeys skips message keys in the current receiving chain.
|
||||||
@ -119,14 +125,11 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
|
|||||||
if until < uint(s.RecvCh.N) {
|
if until < uint(s.RecvCh.N) {
|
||||||
return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
|
return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
|
||||||
}
|
}
|
||||||
nSkipped, err := s.MkSkipped.Count(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if until-uint(s.RecvCh.N)+nSkipped > s.MaxSkip {
|
if uint(s.RecvCh.N)+s.MaxSkip < until {
|
||||||
return nil, fmt.Errorf("too many messages")
|
return nil, fmt.Errorf("too many messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
skipped := []skippedKey{}
|
skipped := []skippedKey{}
|
||||||
for uint(s.RecvCh.N) < until {
|
for uint(s.RecvCh.N) < until {
|
||||||
mk := s.RecvCh.step()
|
mk := s.RecvCh.step()
|
||||||
@ -134,32 +137,31 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
|
|||||||
key: key,
|
key: key,
|
||||||
nr: uint(s.RecvCh.N - 1),
|
nr: uint(s.RecvCh.N - 1),
|
||||||
mk: mk,
|
mk: mk,
|
||||||
|
seq: s.KeysCount,
|
||||||
})
|
})
|
||||||
|
// Increment key count
|
||||||
|
s.KeysCount++
|
||||||
|
|
||||||
}
|
}
|
||||||
return skipped, nil
|
return skipped, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) applyChanges(sc State, skipped []skippedKey) error {
|
func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error {
|
||||||
*s = sc
|
*s = sc
|
||||||
for _, skipped := range skipped {
|
for _, skipped := range skipped {
|
||||||
if err := s.MkSkipped.Put(skipped.key, skipped.nr, skipped.mk); err != nil {
|
if err := s.MkSkipped.Put(sessionID, skipped.key, skipped.nr, skipped.mk, skipped.seq); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.KeysCount >= s.MaxKeep {
|
||||||
|
if err := s.MkSkipped.DeleteOldMks(sessionID, s.KeysCount-s.MaxKeep); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) deleteSkippedKeys(key Key) error {
|
|
||||||
|
|
||||||
s.DeleteKeys[s.Step] = key
|
|
||||||
s.Step++
|
|
||||||
if hk, ok := s.DeleteKeys[s.Step-s.MaxKeep]; ok {
|
|
||||||
if err := s.MkSkipped.DeletePk(hk); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.DeleteKeys, s.Step-s.MaxKeep)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user