From ee3c05c79b2e9670efb0b27a9996ca9017e476f6 Mon Sep 17 00:00:00 2001 From: Andrea Maria Piana Date: Mon, 5 Nov 2018 20:00:04 +0100 Subject: [PATCH] Change handling of skipped/deleted keys & add version (#1261) - Skipped keys The purpose of limiting the number of skipped keys generated is to avoid a dos attack whereby an attacker would send a large N, forcing the device to compute all the keys between currentN..N . Previously the logic for handling skipped keys was: - If in the current receiving chain there are more than maxSkip keys, throw an error This is problematic as in long-lived session dropped/unreceived messages starts piling up, eventually reaching the threshold (1000 dropped/unreceived messages). This logic has been changed to be more inline with signals spec, and now it is: - If N is > currentN + maxSkip, throw an error The purpose of limiting the number of skipped keys stored is to avoid a dos attack whereby an attacker would force us to store a large number of keys, filling up our storage. Previously the logic for handling old keys was: - Once you have maxKeep ratchet steps, delete any key from currentRatchet - maxKeep. This, in combination with the maxSkip implementation, capped the number of stored keys to maxSkip * maxKeep. The logic has been changed to: - Keep a maximum of MaxMessageKeysPerSession and additionally we delete any key that has a sequence number < currentSeqNum - maxKeep - Version We check now the version of the bundle so that when we get a bundle from the same installationID with a higher version, we mark the previous bundle as expired and use the new bundle the next time a message is sent --- Gopkg.lock | 4 +- Gopkg.toml | 2 +- services/shhext/chat/encryption.go | 33 ++- services/shhext/chat/encryption_test.go | 210 ++++++++++++++++++ services/shhext/chat/migrations/bindata.go | 56 ++++- services/shhext/chat/protocol.go | 3 +- services/shhext/chat/protocol_test.go | 2 +- services/shhext/chat/sql_lite_persistence.go | 118 +++++++--- .../sql_lite_persistence_keys_storage_test.go | 115 +++++++--- .../shhext/chat/sql_lite_persistence_test.go | 83 ++++++- static/bindata.go | 12 +- .../1540715431_add_version.down.sql | 3 + .../migrations/1540715431_add_version.up.sql | 5 + .../status-im/doubleratchet/keys_storage.go | 108 +++++++-- .../status-im/doubleratchet/options.go | 14 +- .../status-im/doubleratchet/session.go | 16 +- .../status-im/doubleratchet/session_he.go | 5 +- .../status-im/doubleratchet/state.go | 62 +++--- 18 files changed, 690 insertions(+), 161 deletions(-) create mode 100644 static/migrations/1540715431_add_version.down.sql create mode 100644 static/migrations/1540715431_add_version.up.sql diff --git a/Gopkg.lock b/Gopkg.lock index e9c82ceff..7a4523891 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -782,11 +782,11 @@ version = "v1.1" [[projects]] - digest = "1:b80bfc6278c27edcb1946fc38bc61f4c3a1cd50978083a4a79b7c875c70ed9d8" + digest = "1:41fb72d7a71f37f1f9c766d965178636ecda21b429b1f2e3fff42cfc31279751" name = "github.com/status-im/doubleratchet" packages = ["."] pruneopts = "NUT" - revision = "321788dbb6eac36f7dab04e631db139e13bb280b" + revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4" [[projects]] branch = "master" diff --git a/Gopkg.toml b/Gopkg.toml index 2a62a1c2c..ca65e677c 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -160,7 +160,7 @@ [[constraint]] name = "github.com/status-im/doubleratchet" - revision = "321788dbb6eac36f7dab04e631db139e13bb280b" + revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4" [[constraint]] name = "github.com/status-im/migrate" diff --git a/services/shhext/chat/encryption.go b/services/shhext/chat/encryption.go index 69882956d..bf3543769 100644 --- a/services/shhext/chat/encryption.go +++ b/services/shhext/chat/encryption.go @@ -21,6 +21,15 @@ var ErrSessionNotFound = errors.New("session not found") // If we have no bundles, we use a constant so that the message can reach any device const noInstallationID = "none" +// How many consecutive messages can be skipped in the receiving chain +const maxSkip = 1000 + +// Any message with seqNo <= currentSeq - maxKeep will be deleted +const maxKeep = 3000 + +// How many keys do we store in total per session +const maxMessageKeysPerSession = 2000 + // EncryptionService defines a service that is responsible for the encryption aspect of the protocol type EncryptionService struct { log log.Logger @@ -251,9 +260,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k keyPair, s.persistence.GetSessionStorage(), dr.WithKeysStorage(s.persistence.GetKeysStorage()), - // TODO: Temporarily increase to a high number, until - // we make sure it's a sliding window rather than dropping - dr.WithMaxSkip(10000), + dr.WithMaxSkip(maxSkip), + dr.WithMaxKeep(maxKeep), + dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession), dr.WithCrypto(crypto.EthereumCrypto{})) } else { session, err = dr.NewWithRemoteKey( @@ -262,9 +271,9 @@ func (s *EncryptionService) createNewSession(drInfo *RatchetInfo, sk [32]byte, k keyPair.PubKey, s.persistence.GetSessionStorage(), dr.WithKeysStorage(s.persistence.GetKeysStorage()), - // TODO: Temporarily increase to a high number, until - // we make sure it's a sliding window rather than dropping - dr.WithMaxSkip(10000), + dr.WithMaxSkip(maxSkip), + dr.WithMaxKeep(maxKeep), + dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession), dr.WithCrypto(crypto.EthereumCrypto{})) } @@ -291,9 +300,9 @@ func (s *EncryptionService) encryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr drInfo.ID, sessionStorage, dr.WithKeysStorage(s.persistence.GetKeysStorage()), - // TODO: Temporarily increase to a high number, until - // we make sure it's a sliding window rather than dropping - dr.WithMaxSkip(10000), + dr.WithMaxSkip(maxSkip), + dr.WithMaxKeep(maxKeep), + dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession), dr.WithCrypto(crypto.EthereumCrypto{}), ) if err != nil { @@ -342,9 +351,9 @@ func (s *EncryptionService) decryptUsingDR(theirIdentityKey *ecdsa.PublicKey, dr drInfo.ID, sessionStorage, dr.WithKeysStorage(s.persistence.GetKeysStorage()), - // TODO: Temporarily increase to a high number, until - // we make sure it's a sliding window rather than dropping - dr.WithMaxSkip(10000), + dr.WithMaxSkip(maxSkip), + dr.WithMaxKeep(maxKeep), + dr.WithMaxMessageKeysPerSession(maxMessageKeysPerSession), dr.WithCrypto(crypto.EthereumCrypto{}), ) if err != nil { diff --git a/services/shhext/chat/encryption_test.go b/services/shhext/chat/encryption_test.go index 0dce73675..de28c3ac7 100644 --- a/services/shhext/chat/encryption_test.go +++ b/services/shhext/chat/encryption_test.go @@ -314,6 +314,213 @@ func (s *EncryptionServiceTestSuite) TestConversation() { s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") } +// Previous implementation allowed max maxSkip keys in the same receiving chain +// leading to a problem whereby dropped messages would accumulate and eventually +// we would not be able to decrypt any new message anymore. +// Here we are testing that maxSkip only applies to *consecutive* messages, not +// overall. +func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() { + bobText := []byte("text") + + bobKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + aliceKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + // Create a bundle + bobBundle, err := s.bob.CreateBundle(bobKey) + s.Require().NoError(err) + + // We add bob bundle + _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle) + s.Require().NoError(err) + + // Create a bundle + aliceBundle, err := s.alice.CreateBundle(aliceKey) + s.Require().NoError(err) + + // We add alice bundle + _, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle) + s.Require().NoError(err) + + // Bob sends a message + + for i := 0; i < maxSkip; i++ { + _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + } + + // Bob sends a message + bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + + // Alice receives the message + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1) + s.Require().NoError(err) + + // Bob sends a message + _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + + // Bob sends a message + bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + + // Alice receives the message, we should have maxSkip + 1 keys in the db, but + // we should not throw an error + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2) + s.Require().NoError(err) +} + +// Test that an error is thrown if max skip is reached +func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() { + bobText := []byte("text") + + bobKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + aliceKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + // Create a bundle + bobBundle, err := s.bob.CreateBundle(bobKey) + s.Require().NoError(err) + + // We add bob bundle + _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle) + s.Require().NoError(err) + + // Create a bundle + aliceBundle, err := s.alice.CreateBundle(aliceKey) + s.Require().NoError(err) + + // We add alice bundle + _, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle) + s.Require().NoError(err) + + // Bob sends a message + + for i := 0; i < maxSkip+1; i++ { + _, err = s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + } + + // Bob sends a message + bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + + // Alice receives the message + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1) + s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err) +} + +func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() { + bobText := []byte("text") + + bobKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + aliceKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + // Create a bundle + bobBundle, err := s.bob.CreateBundle(bobKey) + s.Require().NoError(err) + + // We add bob bundle + _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle) + s.Require().NoError(err) + + // Create a bundle + aliceBundle, err := s.alice.CreateBundle(aliceKey) + s.Require().NoError(err) + + // We add alice bundle + _, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle) + s.Require().NoError(err) + + // We create just enough messages so that the first key should be deleted + + nMessages := maxMessageKeysPerSession + maxMessageKeysPerSession/maxSkip + 2 + messages := make([]map[string]*DirectMessageProtocol, nMessages) + for i := 0; i < nMessages; i++ { + m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + + messages[i] = m + + // We decrypt some messages otherwise we hit maxSkip limit + if i%maxSkip == 0 { + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m) + s.Require().NoError(err) + } + + } + + // Another message to trigger the deletion + m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + s.Require().NoError(err) + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m) + s.Require().NoError(err) + + // We decrypt the first message, and it should fail + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1]) + s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) + + // We decrypt the second message, and it should be decrypted + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[2]) + s.Require().NoError(err) +} + +func (s *EncryptionServiceTestSuite) TestMaxKeep() { + bobText := []byte("text") + + bobKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + aliceKey, err := crypto.GenerateKey() + s.Require().NoError(err) + + // Create a bundle + bobBundle, err := s.bob.CreateBundle(bobKey) + s.Require().NoError(err) + + // We add bob bundle + _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle) + s.Require().NoError(err) + + // Create a bundle + aliceBundle, err := s.alice.CreateBundle(aliceKey) + s.Require().NoError(err) + + // We add alice bundle + _, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle) + s.Require().NoError(err) + + // We decrypt all messages but 1 & 2 + messages := make([]map[string]*DirectMessageProtocol, maxKeep) + for i := 0; i < maxKeep; i++ { + m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText) + messages[i] = m + s.Require().NoError(err) + + if i != 0 && i != 1 { + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m) + s.Require().NoError(err) + } + + } + + // We decrypt the first message, and it should fail, as it should have been removed + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0]) + s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err) + + // We decrypt the second message, and it should be decrypted + _, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1]) + s.Require().NoError(err) +} + // Alice has Bob's bundle // Bob has Alice's bundle // Bob sends a message to alice @@ -557,6 +764,9 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID) s.Require().NoError(err) + // We set the version + + bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1 err = SignBundle(bobKey, bobBundle2) s.Require().NoError(err) diff --git a/services/shhext/chat/migrations/bindata.go b/services/shhext/chat/migrations/bindata.go index 0d868f2b5..c82ec4296 100644 --- a/services/shhext/chat/migrations/bindata.go +++ b/services/shhext/chat/migrations/bindata.go @@ -4,6 +4,8 @@ // 1536754952_initial_schema.up.sql // 1539249977_update_ratchet_info.down.sql // 1539249977_update_ratchet_info.up.sql +// 1540715431_add_version.down.sql +// 1540715431_add_version.up.sql // static.go // DO NOT EDIT! @@ -87,7 +89,7 @@ func _1536754952_initial_schemaDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)} + info := bindataFileInfo{name: "1536754952_initial_schema.down.sql", size: 83, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -107,7 +109,7 @@ func _1536754952_initial_schemaUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539252806, 0)} + info := bindataFileInfo{name: "1536754952_initial_schema.up.sql", size: 962, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -127,7 +129,7 @@ func _1539249977_update_ratchet_infoDownSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1539250187, 0)} + info := bindataFileInfo{name: "1539249977_update_ratchet_info.down.sql", size: 311, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -147,7 +149,47 @@ func _1539249977_update_ratchet_infoUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1539250201, 0)} + info := bindataFileInfo{name: "1539249977_update_ratchet_info.up.sql", size: 368, mode: os.FileMode(420), modTime: time.Unix(1540738831, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var __1540715431_add_versionDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\xf4\x09\x71\x0d\x52\x08\x71\x74\xf2\x71\x55\xc8\x4e\xad\x2c\x56\x70\x09\xf2\x0f\x50\x70\xf6\xf7\x09\xf5\xf5\x53\x28\x4e\x2d\x2e\xce\xcc\xcf\x8b\xcf\x4c\xb1\xe6\x42\x56\x08\x15\x47\x55\x0c\xd2\x1d\x9f\x9c\x5f\x9a\x57\x82\xaa\x38\xa9\x34\x2f\x25\x27\x15\x55\x6d\x59\x6a\x11\xc8\x00\x6b\x2e\x40\x00\x00\x00\xff\xff\xda\x5d\x80\x2d\x7f\x00\x00\x00") + +func _1540715431_add_versionDownSqlBytes() ([]byte, error) { + return bindataRead( + __1540715431_add_versionDownSql, + "1540715431_add_version.down.sql", + ) +} + +func _1540715431_add_versionDownSql() (*asset, error) { + bytes, err := _1540715431_add_versionDownSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "1540715431_add_version.down.sql", size: 127, mode: os.FileMode(420), modTime: time.Unix(1540989119, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var __1540715431_add_versionUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x8c\xcd\xb1\x0e\x02\x21\x0c\xc6\xf1\xdd\xa7\xf8\x1e\xc1\xdd\x09\xa4\x67\x4c\x7a\x90\x90\x32\x93\xe8\x31\x5c\x54\x2e\x8a\x98\xf8\xf6\x06\xe3\xc2\xa2\xae\x6d\xff\xbf\x1a\x62\x12\xc2\xe0\xdd\x88\x53\x7a\x96\xcd\x4a\xb1\x90\x87\x28\xcd\xf4\x9e\x40\x19\x83\xad\xe3\x30\x5a\x94\x74\x8d\xb9\x5e\xb0\xb7\x42\x3b\xf2\xb0\x4e\x60\x03\x33\x0c\x0d\x2a\xb0\x60\xfd\xab\x2f\x65\x5e\x72\x9c\x27\x68\x76\xba\x3f\xfe\x2c\xbb\xa0\x01\xf1\xb8\xd4\x7c\xff\xfb\xe7\xa1\xe6\xe9\x9c\x3a\xe5\x91\x6e\x4d\xfe\x4a\xbc\x02\x00\x00\xff\xff\x0e\x27\x2c\x52\x09\x01\x00\x00") + +func _1540715431_add_versionUpSqlBytes() ([]byte, error) { + return bindataRead( + __1540715431_add_versionUpSql, + "1540715431_add_version.up.sql", + ) +} + +func _1540715431_add_versionUpSql() (*asset, error) { + bytes, err := _1540715431_add_versionUpSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "1540715431_add_version.up.sql", size: 265, mode: os.FileMode(420), modTime: time.Unix(1540989075, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -167,7 +209,7 @@ func staticGo() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1537862328, 0)} + info := bindataFileInfo{name: "static.go", size: 188, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -228,6 +270,8 @@ var _bindata = map[string]func() (*asset, error){ "1536754952_initial_schema.up.sql": _1536754952_initial_schemaUpSql, "1539249977_update_ratchet_info.down.sql": _1539249977_update_ratchet_infoDownSql, "1539249977_update_ratchet_info.up.sql": _1539249977_update_ratchet_infoUpSql, + "1540715431_add_version.down.sql": _1540715431_add_versionDownSql, + "1540715431_add_version.up.sql": _1540715431_add_versionUpSql, "static.go": staticGo, } @@ -275,6 +319,8 @@ var _bintree = &bintree{nil, map[string]*bintree{ "1536754952_initial_schema.up.sql": &bintree{_1536754952_initial_schemaUpSql, map[string]*bintree{}}, "1539249977_update_ratchet_info.down.sql": &bintree{_1539249977_update_ratchet_infoDownSql, map[string]*bintree{}}, "1539249977_update_ratchet_info.up.sql": &bintree{_1539249977_update_ratchet_infoUpSql, map[string]*bintree{}}, + "1540715431_add_version.down.sql": &bintree{_1540715431_add_versionDownSql, map[string]*bintree{}}, + "1540715431_add_version.up.sql": &bintree{_1540715431_add_versionUpSql, map[string]*bintree{}}, "static.go": &bintree{staticGo, map[string]*bintree{}}, }} diff --git a/services/shhext/chat/protocol.go b/services/shhext/chat/protocol.go index b53616da5..0d9659061 100644 --- a/services/shhext/chat/protocol.go +++ b/services/shhext/chat/protocol.go @@ -61,8 +61,7 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa // BuildDirectMessage marshals a 1:1 chat message given the user identity private key, the recipient's public key, and a payload func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, payload []byte, theirPublicKeys ...*ecdsa.PublicKey) (map[*ecdsa.PublicKey][]byte, error) { response := make(map[*ecdsa.PublicKey][]byte) - publicKeys := append(theirPublicKeys, &myIdentityKey.PublicKey) - for _, publicKey := range publicKeys { + for _, publicKey := range theirPublicKeys { // Encrypt payload encryptionResponse, err := p.encryption.EncryptPayload(publicKey, myIdentityKey, payload) if err != nil { diff --git a/services/shhext/chat/protocol_test.go b/services/shhext/chat/protocol_test.go index ae01b9e42..fed946234 100644 --- a/services/shhext/chat/protocol_test.go +++ b/services/shhext/chat/protocol_test.go @@ -56,7 +56,7 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() { }) s.NoError(err) - marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey) + marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey, &aliceKey.PublicKey) s.NoError(err) s.NotNil(marshaledMsg, "It creates a message") diff --git a/services/shhext/chat/sql_lite_persistence.go b/services/shhext/chat/sql_lite_persistence.go index 9ace508b7..8b69b01f0 100644 --- a/services/shhext/chat/sql_lite_persistence.go +++ b/services/shhext/chat/sql_lite_persistence.go @@ -16,6 +16,9 @@ import ( "github.com/status-im/status-go/services/shhext/chat/migrations" ) +// A safe max number of rows +const maxNumberOfRows = 100000000 + // SQLLitePersistence represents a persistence service tied to an SQLite database type SQLLitePersistence struct { db *sql.DB @@ -107,7 +110,20 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error { } for installationID, signedPreKey := range b.GetBundle().GetSignedPreKeys() { - stmt, err := tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, timestamp) VALUES(?, ?, ?, ?, ?)") + var version uint32 + stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1") + if err != nil { + return err + } + + defer stmt.Close() + + err = stmt.QueryRow(installationID, b.GetBundle().GetIdentity()).Scan(&version) + if err != nil && err != sql.ErrNoRows { + return err + } + + stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)") if err != nil { return err } @@ -118,6 +134,7 @@ func (s *SQLLitePersistence) AddPrivateBundle(b *BundleContainer) error { b.GetPrivateSignedPreKey(), signedPreKey.GetSignedPreKey(), installationID, + version+1, time.Now().UnixNano(), ) if err != nil { @@ -144,15 +161,18 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() { signedPreKey := signedPreKeyContainer.GetSignedPreKey() - insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, timestamp) VALUES( ?, ?, ?, ?)") + version := signedPreKeyContainer.GetVersion() + insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)") if err != nil { return err } defer insertStmt.Close() + _, err = insertStmt.Exec( b.GetIdentity(), signedPreKey, installationID, + version, time.Now().UnixNano(), ) if err != nil { @@ -160,7 +180,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { return err } // Mark old bundles as expired - updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND signed_pre_key != ?") + updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?") if err != nil { return err } @@ -169,7 +189,7 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { _, err = updateStmt.Exec( b.GetIdentity(), installationID, - signedPreKey, + version, ) if err != nil { _ = tx.Rollback() @@ -281,7 +301,7 @@ func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error { func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundle, error) { identity := crypto.CompressPubkey(publicKey) - stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id FROM bundles WHERE expired = 0 AND identity = ? ORDER BY timestamp DESC") + stmt, err := s.db.Prepare("SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? ORDER BY version DESC") if err != nil { return nil, err } @@ -304,16 +324,21 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey) (*Bundl for rows.Next() { var signedPreKey []byte var installationID string + var version uint32 rowCount++ err = rows.Scan( &signedPreKey, &installationID, + &version, ) if err != nil { return nil, err } - bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey} + bundle.SignedPreKeys[installationID] = &SignedPreKey{ + SignedPreKey: signedPreKey, + Version: version, + } } @@ -448,17 +473,53 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro } // Put stores a key with the specified public key, message number and message key -func (s *SQLLiteKeysStorage) Put(pubKey dr.Key, msgNum uint, mk dr.Key) error { - stmt, err := s.db.Prepare("insert into keys(public_key, msg_num, message_key) values(?, ?, ?)") +func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error { + stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)") if err != nil { return err } defer stmt.Close() _, err = stmt.Exec( + sessionID, pubKey[:], msgNum, mk[:], + seqNum, + ) + + return err +} + +// DeleteOldMks caps remove any key < seq_num, included +func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error { + stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?") + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + sessionID, + deleteUntil, + ) + + return err +} + +// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion +func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error { + stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)") + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + sessionID, + // We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT + maxNumberOfRows, + maxKeysPerSession, ) return err @@ -480,21 +541,6 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error { return err } -// DeletePk deletes the keys with the specified public key -func (s *SQLLiteKeysStorage) DeletePk(pubKey dr.Key) error { - stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ?") - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - pubKey[:], - ) - - return err -} - // Count returns the count of keys with the specified public key func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) { stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?") @@ -512,6 +558,23 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) { return count, nil } +// CountAll returns the count of keys with the specified public key +func (s *SQLLiteKeysStorage) CountAll() (uint, error) { + stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys") + if err != nil { + return 0, err + } + defer stmt.Close() + + var count uint + err = stmt.QueryRow().Scan(&count) + if err != nil { + return 0, err + } + + return count, nil +} + // All returns nil func (s *SQLLiteKeysStorage) All() (map[dr.Key]map[uint]dr.Key, error) { return nil, nil @@ -525,6 +588,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error { dhsPrivate := dhs.PrivateKey() pn := state.PN step := state.Step + keysCount := state.KeysCount rootChainKey := state.RootCh.CK[:] @@ -534,7 +598,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error { recvChainKey := state.RecvCh.CK[:] recvChainN := state.RecvCh.N - stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") if err != nil { return err } @@ -552,6 +616,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error { recvChainN, pn, step, + keysCount, ) return err @@ -559,7 +624,7 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error { // Load retrieves the double ratchet state for a given ID func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { - stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step FROM sessions WHERE id = ?") + stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?") if err != nil { return nil, err } @@ -577,6 +642,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { recvChainN uint pn uint step uint + keysCount uint ) err = stmt.QueryRow(id).Scan( @@ -590,6 +656,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { &recvChainN, &pn, &step, + &keysCount, ) switch err { case sql.ErrNoRows: @@ -599,6 +666,7 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { state.PN = uint32(pn) state.Step = step + state.KeysCount = keysCount state.DHs = ecrypto.DHPair{ PrvKey: toKey(dhsPrivate), diff --git a/services/shhext/chat/sql_lite_persistence_keys_storage_test.go b/services/shhext/chat/sql_lite_persistence_keys_storage_test.go index 47939dc59..a37b9cf7d 100644 --- a/services/shhext/chat/sql_lite_persistence_keys_storage_test.go +++ b/services/shhext/chat/sql_lite_persistence_keys_storage_test.go @@ -10,8 +10,11 @@ import ( var ( pubKey1 = dr.Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c} pubKey2 = dr.Key{0xec, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} - mk1 = dr.Key{0xeb, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} - mk2 = dr.Key{0xed, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} + mk1 = dr.Key{0x00, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} + mk2 = dr.Key{0x01, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} + mk3 = dr.Key{0x02, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} + mk4 = dr.Key{0x03, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} + mk5 = dr.Key{0x04, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} ) func TestSQLLitePersistenceKeysStorageTestSuite(t *testing.T) { @@ -47,10 +50,84 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLiteGetMissin func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Put() { // Act and assert. - err := s.service.Put(pubKey1, 0, mk1) + err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1) s.NoError(err) } +func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_DeleteOldMks() { + // Insert keys out-of-order + err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey1, 2, mk3, 20) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey1, 3, mk4, 21) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey1, 4, mk5, 22) + s.NoError(err) + + err = s.service.DeleteOldMks([]byte("session-id"), 20) + s.NoError(err) + + _, ok, err := s.service.Get(pubKey1, 0) + s.NoError(err) + s.False(ok) + + _, ok, err = s.service.Get(pubKey1, 1) + s.NoError(err) + s.False(ok) + + _, ok, err = s.service.Get(pubKey1, 2) + s.NoError(err) + s.False(ok) + + _, ok, err = s.service.Get(pubKey1, 3) + s.NoError(err) + s.True(ok) + + _, ok, err = s.service.Get(pubKey1, 4) + s.NoError(err) + s.True(ok) +} + +func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_TruncateMks() { + // Insert keys out-of-order + err := s.service.Put([]byte("session-id"), pubKey2, 2, mk5, 5) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey2, 0, mk3, 3) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey1, 1, mk2, 2) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey2, 1, mk4, 4) + s.NoError(err) + err = s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1) + s.NoError(err) + + err = s.service.TruncateMks([]byte("session-id"), 2) + s.NoError(err) + + _, ok, err := s.service.Get(pubKey1, 0) + s.NoError(err) + s.False(ok) + + _, ok, err = s.service.Get(pubKey1, 1) + s.NoError(err) + s.False(ok) + + _, ok, err = s.service.Get(pubKey2, 0) + s.NoError(err) + s.False(ok) + + _, ok, err = s.service.Get(pubKey2, 1) + s.NoError(err) + s.True(ok) + + _, ok, err = s.service.Get(pubKey2, 2) + s.NoError(err) + s.True(ok) +} + func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Count() { // Act. @@ -71,12 +148,8 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Delete() func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() { - // Act and assert. - err := s.service.DeletePk(pubKey1) - s.NoError(err) - // Act. - err = s.service.Put(pubKey1, 0, mk1) + err := s.service.Put([]byte("session-id"), pubKey1, 0, mk1, 1) s.NoError(err) k, ok, err := s.service.Get(pubKey1, 0) @@ -124,30 +197,4 @@ func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLite_Flow() { // Assert. s.NoError(err) s.EqualValues(0, cnt) - - // Act. - err = s.service.Put(pubKey1, 0, mk1) - s.NoError(err) - - err = s.service.Put(pubKey2, 0, mk2) - s.NoError(err) - - err = s.service.DeletePk(pubKey1) - s.NoError(err) - - err = s.service.DeletePk(pubKey1) - s.NoError(err) - - err = s.service.DeletePk(pubKey2) - s.NoError(err) - - cn1, err := s.service.Count(pubKey1) - s.NoError(err) - - cn2, err := s.service.Count(pubKey2) - s.NoError(err) - - // Assert. - s.Empty(cn1) - s.Empty(cn2) } diff --git a/services/shhext/chat/sql_lite_persistence_test.go b/services/shhext/chat/sql_lite_persistence_test.go index 3c89e3e67..c404f2b30 100644 --- a/services/shhext/chat/sql_lite_persistence_test.go +++ b/services/shhext/chat/sql_lite_persistence_test.go @@ -51,7 +51,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() { s.Require().NoError(err) actualKey, err := s.service.GetPrivateKeyBundle([]byte("non-existing")) - s.Require().NoError(err, "It does not return an error if the bundle is not there") + s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualKey) anyPrivateBundle, err := s.service.GetAnyPrivateBundle([]byte("non-existing-id")) @@ -82,7 +82,7 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() { s.Require().NoError(err) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey) - s.Require().NoError(err, "It does not return an error if the bundle is not there") + s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) bundleContainer, err := NewBundleContainer(key, "1") @@ -98,12 +98,82 @@ func (s *SQLLitePersistenceTestSuite) TestPublicBundle() { s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys") } +func (s *SQLLitePersistenceTestSuite) TestUpdatedBundle() { + key, err := crypto.GenerateKey() + s.Require().NoError(err) + + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey) + s.Require().NoError(err, "Error was not returned even though bundle is not there") + s.Nil(actualBundle) + + // Create & add initial bundle + bundleContainer, err := NewBundleContainer(key, "1") + s.Require().NoError(err) + + bundle := bundleContainer.GetBundle() + err = s.service.AddPublicBundle(bundle) + s.Require().NoError(err) + + // Create & add a new bundle + bundleContainer, err = NewBundleContainer(key, "1") + s.Require().NoError(err) + bundle = bundleContainer.GetBundle() + // We set the version + bundle.GetSignedPreKeys()["1"].Version = 1 + + err = s.service.AddPublicBundle(bundle) + s.Require().NoError(err) + + actualBundle, err = s.service.GetPublicBundle(&key.PublicKey) + s.Require().NoError(err) + s.Equal(bundle.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") + s.Equal(bundle.GetSignedPreKeys(), actualBundle.GetSignedPreKeys(), "It sets the right prekeys") +} + +func (s *SQLLitePersistenceTestSuite) TestOutOfOrderBundles() { + key, err := crypto.GenerateKey() + s.Require().NoError(err) + + actualBundle, err := s.service.GetPublicBundle(&key.PublicKey) + s.Require().NoError(err, "Error was not returned even though bundle is not there") + s.Nil(actualBundle) + + // Create & add initial bundle + bundleContainer, err := NewBundleContainer(key, "1") + s.Require().NoError(err) + + bundle1 := bundleContainer.GetBundle() + err = s.service.AddPublicBundle(bundle1) + s.Require().NoError(err) + + // Create & add a new bundle + bundleContainer, err = NewBundleContainer(key, "1") + s.Require().NoError(err) + + bundle2 := bundleContainer.GetBundle() + // We set the version + bundle2.GetSignedPreKeys()["1"].Version = 1 + + err = s.service.AddPublicBundle(bundle2) + s.Require().NoError(err) + + // Add again the initial bundle + err = s.service.AddPublicBundle(bundle1) + s.Require().NoError(err) + + actualBundle, err = s.service.GetPublicBundle(&key.PublicKey) + s.Require().NoError(err) + s.Equal(bundle2.GetIdentity(), actualBundle.GetIdentity(), "It sets the right identity") + s.Equal(bundle2.GetSignedPreKeys()["1"].GetVersion(), uint32(1)) + s.Equal(bundle2.GetSignedPreKeys()["1"].GetSignedPreKey(), actualBundle.GetSignedPreKeys()["1"].GetSignedPreKey(), "It sets the right prekeys") +} + func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() { key, err := crypto.GenerateKey() s.Require().NoError(err) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey) - s.Require().NoError(err, "It does not return an error if the bundle is not there") + s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) bundleContainer, err := NewBundleContainer(key, "1") @@ -120,8 +190,10 @@ func (s *SQLLitePersistenceTestSuite) TestMultiplePublicBundle() { // Adding a different bundle bundleContainer, err = NewBundleContainer(key, "1") s.Require().NoError(err) - + // We set the version bundle = bundleContainer.GetBundle() + bundle.GetSignedPreKeys()["1"].Version = 1 + err = s.service.AddPublicBundle(bundle) s.Require().NoError(err) @@ -139,7 +211,7 @@ func (s *SQLLitePersistenceTestSuite) TestMultiDevicePublicBundle() { s.Require().NoError(err) actualBundle, err := s.service.GetPublicBundle(&key.PublicKey) - s.Require().NoError(err, "It does not return an error if the bundle is not there") + s.Require().NoError(err, "Error was not returned even though bundle is not there") s.Nil(actualBundle) bundleContainer, err := NewBundleContainer(key, "1") @@ -273,4 +345,3 @@ func (s *SQLLitePersistenceTestSuite) TestRatchetInfoNoBundle() { } // TODO: Add test for MarkBundleExpired -// TODO: Add test for AddPublicBundle checking that it expires previous bundles diff --git a/static/bindata.go b/static/bindata.go index 2de9e4523..47851caa0 100644 --- a/static/bindata.go +++ b/static/bindata.go @@ -88,7 +88,7 @@ func ConfigCliFleetEthBetaJson() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)} + info := bindataFileInfo{name: "../config/cli/fleet-eth.beta.json", size: 3237, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -108,7 +108,7 @@ func ConfigCliFleetEthStagingJson() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)} + info := bindataFileInfo{name: "../config/cli/fleet-eth.staging.json", size: 1838, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -128,7 +128,7 @@ func ConfigCliFleetEthTestJson() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1537514234, 0)} + info := bindataFileInfo{name: "../config/cli/fleet-eth.test.json", size: 1519, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -148,7 +148,7 @@ func ConfigCliLesEnabledJson() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)} + info := bindataFileInfo{name: "../config/cli/les-enabled.json", size: 58, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -168,7 +168,7 @@ func ConfigCliMailserverEnabledJson() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1538032850, 0)} + info := bindataFileInfo{name: "../config/cli/mailserver-enabled.json", size: 176, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -188,7 +188,7 @@ func ConfigStatusChainGenesisJson() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1536858252, 0)} + info := bindataFileInfo{name: "../config/status-chain-genesis.json", size: 612, mode: os.FileMode(420), modTime: time.Unix(1539606161, 0)} a := &asset{bytes: bytes, info: info} return a, nil } diff --git a/static/migrations/1540715431_add_version.down.sql b/static/migrations/1540715431_add_version.down.sql new file mode 100644 index 000000000..a6bc43c51 --- /dev/null +++ b/static/migrations/1540715431_add_version.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE keys DROP COLUMN session_id; +ALTER TABLE sessions DROP COLUMN keys_count; +ALTER TABLE bundles DROP COLUMN version; diff --git a/static/migrations/1540715431_add_version.up.sql b/static/migrations/1540715431_add_version.up.sql new file mode 100644 index 000000000..574eb99e5 --- /dev/null +++ b/static/migrations/1540715431_add_version.up.sql @@ -0,0 +1,5 @@ +DELETE FROM keys; +ALTER TABLE keys ADD COLUMN seq_num INTEGER NOT NULL DEFAULT 0; +ALTER TABLE keys ADD COLUMN session_id BLOB; +ALTER TABLE sessions ADD COLUMN keys_count INTEGER NOT NULL DEFAULT 0; +ALTER TABLE bundles ADD COLUMN version INTEGER NOT NULL DEFAULT 0; diff --git a/vendor/github.com/status-im/doubleratchet/keys_storage.go b/vendor/github.com/status-im/doubleratchet/keys_storage.go index 82dec2e9d..ba6799659 100644 --- a/vendor/github.com/status-im/doubleratchet/keys_storage.go +++ b/vendor/github.com/status-im/doubleratchet/keys_storage.go @@ -1,18 +1,26 @@ package doubleratchet +import ( + "bytes" + "sort" +) + // KeysStorage is an interface of an abstract in-memory or persistent keys storage. type KeysStorage interface { // Get returns a message key by the given key and message number. Get(k Key, msgNum uint) (mk Key, ok bool, err error) // Put saves the given mk under the specified key and msgNum. - Put(k Key, msgNum uint, mk Key) error + Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error // DeleteMk ensures there's no message key under the specified key and msgNum. DeleteMk(k Key, msgNum uint) error - // DeletePk ensures there's no message keys under the specified key. - DeletePk(k Key) error + // DeleteOldMKeys deletes old message keys for a session. + DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error + + // TruncateMks truncates the number of keys to maxKeys. + TruncateMks(sessionID []byte, maxKeys int) error // Count returns number of message keys stored under the specified key. Count(k Key) (uint, error) @@ -23,10 +31,10 @@ type KeysStorage interface { // KeysStorageInMemory is an in-memory message keys storage. type KeysStorageInMemory struct { - keys map[Key]map[uint]Key + keys map[Key]map[uint]InMemoryKey } -// See KeysStorage. +// Get returns a message key by the given key and message number. func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) { if s.keys == nil { return Key{}, false, nil @@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) { if !ok { return Key{}, false, nil } - return mk, true, nil + return mk.messageKey, true, nil } -// See KeysStorage. -func (s *KeysStorageInMemory) Put(pubKey Key, msgNum uint, mk Key) error { +type InMemoryKey struct { + messageKey Key + seqNum uint + sessionID []byte +} + +// Put saves the given mk under the specified key and msgNum. +func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error { if s.keys == nil { - s.keys = make(map[Key]map[uint]Key) + s.keys = make(map[Key]map[uint]InMemoryKey) } if _, ok := s.keys[pubKey]; !ok { - s.keys[pubKey] = make(map[uint]Key) + s.keys[pubKey] = make(map[uint]InMemoryKey) + } + s.keys[pubKey][msgNum] = InMemoryKey{ + sessionID: sessionID, + messageKey: mk, + seqNum: seqNum, } - s.keys[pubKey][msgNum] = mk return nil } -// See KeysStorage. +// DeleteMk ensures there's no message key under the specified key and msgNum. func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error { if s.keys == nil { return nil @@ -72,19 +90,58 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error { return nil } -// See KeysStorage. -func (s *KeysStorageInMemory) DeletePk(pubKey Key) error { - if s.keys == nil { +// TruncateMks truncates the number of keys to maxKeys. +func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error { + var seqNos []uint + // Collect all seq numbers + for _, keys := range s.keys { + for _, inMemoryKey := range keys { + if bytes.Equal(inMemoryKey.sessionID, sessionID) { + seqNos = append(seqNos, inMemoryKey.seqNum) + } + } + } + + // Nothing to do if we haven't reached the limit + if len(seqNos) <= maxKeys { return nil } - if _, ok := s.keys[pubKey]; !ok { - return nil + + // Take the sequence numbers we care about + sort.Slice(seqNos, func(i, j int) bool { return seqNos[i] < seqNos[j] }) + toDeleteSlice := seqNos[:len(seqNos)-maxKeys] + + // Put in map for easier lookup + toDelete := make(map[uint]bool) + + for _, seqNo := range toDeleteSlice { + toDelete[seqNo] = true } - delete(s.keys, pubKey) + + for pubKey, keys := range s.keys { + for i, inMemoryKey := range keys { + if toDelete[inMemoryKey.seqNum] && bytes.Equal(inMemoryKey.sessionID, sessionID) { + delete(s.keys[pubKey], i) + } + } + } + return nil } -// See KeysStorage. +// DeleteOldMKeys deletes old message keys for a session. +func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error { + for pubKey, keys := range s.keys { + for i, inMemoryKey := range keys { + if inMemoryKey.seqNum <= deleteUntilSeqKey && bytes.Equal(inMemoryKey.sessionID, sessionID) { + delete(s.keys[pubKey], i) + } + } + } + return nil +} + +// Count returns number of message keys stored under the specified key. func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) { if s.keys == nil { return 0, nil @@ -92,7 +149,16 @@ func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) { return uint(len(s.keys[pubKey])), nil } -// See KeysStorage. +// All returns all the keys func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) { - return s.keys, nil + response := make(map[Key]map[uint]Key) + + for pubKey, keys := range s.keys { + response[pubKey] = make(map[uint]Key) + for n, key := range keys { + response[pubKey][n] = key.messageKey + } + } + + return response, nil } diff --git a/vendor/github.com/status-im/doubleratchet/options.go b/vendor/github.com/status-im/doubleratchet/options.go index 1c740bf66..d3d6d4e5c 100644 --- a/vendor/github.com/status-im/doubleratchet/options.go +++ b/vendor/github.com/status-im/doubleratchet/options.go @@ -17,7 +17,7 @@ func WithMaxSkip(n int) option { } } -// WithMaxKeep specifies the maximum number of ratchet steps before a message is deleted. +// WithMaxKeep specifies how long we keep message keys, counted in number of messages received // nolint: golint func WithMaxKeep(n int) option { return func(s *State) error { @@ -29,6 +29,18 @@ func WithMaxKeep(n int) option { } } +// WithMaxMessageKeysPerSession specifies the maximum number of message keys per session +// nolint: golint +func WithMaxMessageKeysPerSession(n int) option { + return func(s *State) error { + if n < 0 { + return fmt.Errorf("n must be non-negative") + } + s.MaxMessageKeysPerSession = n + return nil + } +} + // WithKeysStorage replaces the default keys storage with the specified. // nolint: golint func WithKeysStorage(ks KeysStorage) option { diff --git a/vendor/github.com/status-im/doubleratchet/session.go b/vendor/github.com/status-im/doubleratchet/session.go index 4da80e2c9..71f2fe48a 100644 --- a/vendor/github.com/status-im/doubleratchet/session.go +++ b/vendor/github.com/status-im/doubleratchet/session.go @@ -130,7 +130,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) { ) // Is there a new ratchet key? - isDHStepped := false if m.Header.DH != sc.DHr { if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil { return nil, fmt.Errorf("can't skip previous chain message keys: %s", err) @@ -138,7 +137,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) { if err = sc.dhRatchet(m.Header); err != nil { return nil, fmt.Errorf("can't perform ratchet step: %s", err) } - isDHStepped = true } // After all, update the current chain. @@ -151,16 +149,12 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) { return nil, fmt.Errorf("can't decrypt: %s", err) } - // Apply changes. - if err := s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil { - return nil, err - } + // Increment the number of keys + sc.KeysCount++ - if isDHStepped { - err = s.deleteSkippedKeys(s.DHr) - if err != nil { - return nil, err - } + // Apply changes. + if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil { + return nil, err } // Store state diff --git a/vendor/github.com/status-im/doubleratchet/session_he.go b/vendor/github.com/status-im/doubleratchet/session_he.go index 32f7e2cf7..0d2221b68 100644 --- a/vendor/github.com/status-im/doubleratchet/session_he.go +++ b/vendor/github.com/status-im/doubleratchet/session_he.go @@ -103,12 +103,9 @@ func (s *sessionHE) RatchetDecrypt(m MessageHE, ad []byte) ([]byte, error) { return nil, fmt.Errorf("can't decrypt: %s", err) } - if err = s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil { + if err = s.applyChanges(sc, []byte("FIXME"), append(skippedKeys1, skippedKeys2...)); err != nil { return nil, fmt.Errorf("failed to apply changes: %s", err) } - if step { - _ = s.deleteSkippedKeys(s.HKr) - } return plaintext, nil } diff --git a/vendor/github.com/status-im/doubleratchet/state.go b/vendor/github.com/status-im/doubleratchet/state.go index 69613cdae..f3fe734a8 100644 --- a/vendor/github.com/status-im/doubleratchet/state.go +++ b/vendor/github.com/status-im/doubleratchet/state.go @@ -42,14 +42,18 @@ type State struct { // Sending header key and next header key. Only used for header encryption. HKs, NHKs Key - // Number of ratchet steps after which all skipped message keys for that key will be deleted. + // How long we keep messages keys, counted in number of messages received, + // for example if MaxKeep is 5 we only keep the last 5 messages keys, deleting everything n - 5. MaxKeep uint + // Max number of message keys per session, older keys will be deleted in FIFO fashion + MaxMessageKeysPerSession int + // The number of the current ratchet step. Step uint - // Which key for the receiving chain was used at the specified step. - DeleteKeys map[uint]Key + // KeysCount the number of keys generated for decrypting + KeysCount uint } func DefaultState(sharedKey Key) State { @@ -61,12 +65,13 @@ func DefaultState(sharedKey Key) State { RootCh: kdfRootChain{CK: sharedKey, Crypto: c}, // Populate CKs and CKr with sharedKey so that both parties could send and receive // messages from the very beginning. - SendCh: kdfChain{CK: sharedKey, Crypto: c}, - RecvCh: kdfChain{CK: sharedKey, Crypto: c}, - MkSkipped: &KeysStorageInMemory{}, - MaxSkip: 1000, - MaxKeep: 100, - DeleteKeys: make(map[uint]Key), + SendCh: kdfChain{CK: sharedKey, Crypto: c}, + RecvCh: kdfChain{CK: sharedKey, Crypto: c}, + MkSkipped: &KeysStorageInMemory{}, + MaxSkip: 1000, + MaxMessageKeysPerSession: 2000, + MaxKeep: 2000, + KeysCount: 0, } } @@ -112,6 +117,7 @@ type skippedKey struct { key Key nr uint mk Key + seq uint } // skipMessageKeys skips message keys in the current receiving chain. @@ -119,14 +125,11 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) { if until < uint(s.RecvCh.N) { return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted") } - nSkipped, err := s.MkSkipped.Count(key) - if err != nil { - return nil, err - } - if until-uint(s.RecvCh.N)+nSkipped > s.MaxSkip { + if uint(s.RecvCh.N)+s.MaxSkip < until { return nil, fmt.Errorf("too many messages") } + skipped := []skippedKey{} for uint(s.RecvCh.N) < until { mk := s.RecvCh.step() @@ -134,32 +137,31 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) { key: key, nr: uint(s.RecvCh.N - 1), mk: mk, + seq: s.KeysCount, }) + // Increment key count + s.KeysCount++ + } return skipped, nil } -func (s *State) applyChanges(sc State, skipped []skippedKey) error { +func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error { *s = sc for _, skipped := range skipped { - if err := s.MkSkipped.Put(skipped.key, skipped.nr, skipped.mk); err != nil { + if err := s.MkSkipped.Put(sessionID, skipped.key, skipped.nr, skipped.mk, skipped.seq); err != nil { + return err + } + } + + if err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession); err != nil { + return err + } + if s.KeysCount >= s.MaxKeep { + if err := s.MkSkipped.DeleteOldMks(sessionID, s.KeysCount-s.MaxKeep); err != nil { return err } } return nil } - -func (s *State) deleteSkippedKeys(key Key) error { - - s.DeleteKeys[s.Step] = key - s.Step++ - if hk, ok := s.DeleteKeys[s.Step-s.MaxKeep]; ok { - if err := s.MkSkipped.DeletePk(hk); err != nil { - return err - } - - delete(s.DeleteKeys, s.Step-s.MaxKeep) - } - return nil -}