Change handling of skipped/deleted keys
The purpose of limiting the number of skipped keys generated is to avoid a dos attack whereby an attacker would send a large N, forcing the device to compute all the keys between currentN..N . Previously the logic for handling skipped keys was: - If in the current receiving chain there are more than maxSkip keys, throw an error This is problematic as in long-lived session dropped/unreceived messages starts piling up, eventually reaching the threshold (1000 dropped/unreceived messages). This logic has been changed to be more inline with signals spec, and now it is: - If N is > currentN + maxSkip, throw an error The purpose of limiting the number of skipped keys stored is to avoid a dos attack whereby an attacker would force us to store a large number of keys, filling up our storage. Previously the logic for handling old keys was: - Once you have maxKeep ratchet steps, delete any key from currentRatchet - maxKeep. This, in combination with the maxSkip implementation, capped the number of stored keys to maxSkip * maxKeep. The logic has been changed to: - Keep a maximum of MaxMessageKeysPerSession and additionally we delete any key that has a sequence number < currentSeqNum - maxKeep
This commit is contained in:
parent
c243ae5a66
commit
7279c44c22
108
keys_storage.go
108
keys_storage.go
|
@ -1,18 +1,26 @@
|
|||
package doubleratchet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
|
||||
type KeysStorage interface {
|
||||
// Get returns a message key by the given key and message number.
|
||||
Get(k Key, msgNum uint) (mk Key, ok bool, err error)
|
||||
|
||||
// Put saves the given mk under the specified key and msgNum.
|
||||
Put(k Key, msgNum uint, mk Key) error
|
||||
Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error
|
||||
|
||||
// DeleteMk ensures there's no message key under the specified key and msgNum.
|
||||
DeleteMk(k Key, msgNum uint) error
|
||||
|
||||
// DeletePk ensures there's no message keys under the specified key.
|
||||
DeletePk(k Key) error
|
||||
// DeleteOldMKeys deletes old message keys for a session.
|
||||
DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error
|
||||
|
||||
// TruncateMks truncates the number of keys to maxKeys.
|
||||
TruncateMks(sessionID []byte, maxKeys int) error
|
||||
|
||||
// Count returns number of message keys stored under the specified key.
|
||||
Count(k Key) (uint, error)
|
||||
|
@ -23,10 +31,10 @@ type KeysStorage interface {
|
|||
|
||||
// KeysStorageInMemory is an in-memory message keys storage.
|
||||
type KeysStorageInMemory struct {
|
||||
keys map[Key]map[uint]Key
|
||||
keys map[Key]map[uint]InMemoryKey
|
||||
}
|
||||
|
||||
// See KeysStorage.
|
||||
// Get returns a message key by the given key and message number.
|
||||
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
|
||||
if s.keys == nil {
|
||||
return Key{}, false, nil
|
||||
|
@ -39,22 +47,32 @@ func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
|
|||
if !ok {
|
||||
return Key{}, false, nil
|
||||
}
|
||||
return mk, true, nil
|
||||
return mk.messageKey, true, nil
|
||||
}
|
||||
|
||||
// See KeysStorage.
|
||||
func (s *KeysStorageInMemory) Put(pubKey Key, msgNum uint, mk Key) error {
|
||||
type InMemoryKey struct {
|
||||
messageKey Key
|
||||
seqNum uint
|
||||
sessionID []byte
|
||||
}
|
||||
|
||||
// Put saves the given mk under the specified key and msgNum.
|
||||
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
|
||||
if s.keys == nil {
|
||||
s.keys = make(map[Key]map[uint]Key)
|
||||
s.keys = make(map[Key]map[uint]InMemoryKey)
|
||||
}
|
||||
if _, ok := s.keys[pubKey]; !ok {
|
||||
s.keys[pubKey] = make(map[uint]Key)
|
||||
s.keys[pubKey] = make(map[uint]InMemoryKey)
|
||||
}
|
||||
s.keys[pubKey][msgNum] = InMemoryKey{
|
||||
sessionID: sessionID,
|
||||
messageKey: mk,
|
||||
seqNum: seqNum,
|
||||
}
|
||||
s.keys[pubKey][msgNum] = mk
|
||||
return nil
|
||||
}
|
||||
|
||||
// See KeysStorage.
|
||||
// DeleteMk ensures there's no message key under the specified key and msgNum.
|
||||
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
|
||||
if s.keys == nil {
|
||||
return nil
|
||||
|
@ -72,19 +90,58 @@ func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// See KeysStorage.
|
||||
func (s *KeysStorageInMemory) DeletePk(pubKey Key) error {
|
||||
if s.keys == nil {
|
||||
// TruncateMks truncates the number of keys to maxKeys.
|
||||
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
|
||||
var seqNos []uint
|
||||
// Collect all seq numbers
|
||||
for _, keys := range s.keys {
|
||||
for _, inMemoryKey := range keys {
|
||||
if bytes.Equal(inMemoryKey.sessionID, sessionID) {
|
||||
seqNos = append(seqNos, inMemoryKey.seqNum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing to do if we haven't reached the limit
|
||||
if len(seqNos) <= maxKeys {
|
||||
return nil
|
||||
}
|
||||
if _, ok := s.keys[pubKey]; !ok {
|
||||
return nil
|
||||
|
||||
// Take the sequence numbers we care about
|
||||
sort.Slice(seqNos, func(i, j int) bool { return seqNos[i] < seqNos[j] })
|
||||
toDeleteSlice := seqNos[:len(seqNos)-maxKeys]
|
||||
|
||||
// Put in map for easier lookup
|
||||
toDelete := make(map[uint]bool)
|
||||
|
||||
for _, seqNo := range toDeleteSlice {
|
||||
toDelete[seqNo] = true
|
||||
}
|
||||
delete(s.keys, pubKey)
|
||||
|
||||
for pubKey, keys := range s.keys {
|
||||
for i, inMemoryKey := range keys {
|
||||
if toDelete[inMemoryKey.seqNum] && bytes.Equal(inMemoryKey.sessionID, sessionID) {
|
||||
delete(s.keys[pubKey], i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// See KeysStorage.
|
||||
// DeleteOldMKeys deletes old message keys for a session.
|
||||
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
|
||||
for pubKey, keys := range s.keys {
|
||||
for i, inMemoryKey := range keys {
|
||||
if inMemoryKey.seqNum <= deleteUntilSeqKey && bytes.Equal(inMemoryKey.sessionID, sessionID) {
|
||||
delete(s.keys[pubKey], i)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns number of message keys stored under the specified key.
|
||||
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
|
||||
if s.keys == nil {
|
||||
return 0, nil
|
||||
|
@ -92,7 +149,16 @@ func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
|
|||
return uint(len(s.keys[pubKey])), nil
|
||||
}
|
||||
|
||||
// See KeysStorage.
|
||||
// All returns all the keys
|
||||
func (s *KeysStorageInMemory) All() (map[Key]map[uint]Key, error) {
|
||||
return s.keys, nil
|
||||
response := make(map[Key]map[uint]Key)
|
||||
|
||||
for pubKey, keys := range s.keys {
|
||||
response[pubKey] = make(map[uint]Key)
|
||||
for n, key := range keys {
|
||||
response[pubKey][n] = key.messageKey
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ func TestKeysStorageInMemory_Put(t *testing.T) {
|
|||
ks := &KeysStorageInMemory{}
|
||||
|
||||
// Act and assert.
|
||||
err := ks.Put(pubKey1, 0, mk)
|
||||
err := ks.Put([]byte("session-id"), pubKey1, 0, mk, 1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -58,15 +58,9 @@ func TestKeysStorageInMemory_Flow(t *testing.T) {
|
|||
// Arrange.
|
||||
ks := &KeysStorageInMemory{}
|
||||
|
||||
t.Run("delete non-existent pubkey", func(t *testing.T) {
|
||||
// Act and assert.
|
||||
err := ks.DeletePk(pubKey1)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("put and get existing", func(t *testing.T) {
|
||||
// Act.
|
||||
err := ks.Put(pubKey1, 0, mk)
|
||||
err := ks.Put([]byte("session-id"), pubKey1, 0, mk, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
k, ok, err := ks.Get(pubKey1, 0)
|
||||
|
@ -138,32 +132,4 @@ func TestKeysStorageInMemory_Flow(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cnt)
|
||||
})
|
||||
|
||||
t.Run("delete existing pubkey", func(t *testing.T) {
|
||||
// Act.
|
||||
err := ks.Put(pubKey1, 0, mk)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ks.Put(pubKey2, 0, mk)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ks.DeletePk(pubKey1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ks.DeletePk(pubKey1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ks.DeletePk(pubKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
cn1, err := ks.Count(pubKey1)
|
||||
require.NoError(t, err)
|
||||
|
||||
cn2, err := ks.Count(pubKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert.
|
||||
require.Empty(t, cn1)
|
||||
require.Empty(t, cn2)
|
||||
})
|
||||
}
|
||||
|
|
14
options.go
14
options.go
|
@ -17,7 +17,7 @@ func WithMaxSkip(n int) option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithMaxKeep specifies the maximum number of ratchet steps before a message is deleted.
|
||||
// WithMaxKeep specifies how long we keep message keys, counted in number of messages received
|
||||
// nolint: golint
|
||||
func WithMaxKeep(n int) option {
|
||||
return func(s *State) error {
|
||||
|
@ -29,6 +29,18 @@ func WithMaxKeep(n int) option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithMaxMessageKeysPerSession specifies the maximum number of message keys per session
|
||||
// nolint: golint
|
||||
func WithMaxMessageKeysPerSession(n int) option {
|
||||
return func(s *State) error {
|
||||
if n < 0 {
|
||||
return fmt.Errorf("n must be non-negative")
|
||||
}
|
||||
s.MaxMessageKeysPerSession = n
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithKeysStorage replaces the default keys storage with the specified.
|
||||
// nolint: golint
|
||||
func WithKeysStorage(ks KeysStorage) option {
|
||||
|
|
19
session.go
19
session.go
|
@ -115,6 +115,9 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("can't decrypt skipped message: %s", err)
|
||||
}
|
||||
_ = s.MkSkipped.DeleteMk(m.Header.DH, uint(m.Header.N))
|
||||
if err := s.store(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
|
@ -127,7 +130,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
)
|
||||
|
||||
// Is there a new ratchet key?
|
||||
isDHStepped := false
|
||||
if m.Header.DH != sc.DHr {
|
||||
if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
|
||||
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
|
||||
|
@ -135,7 +137,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
if err = sc.dhRatchet(m.Header); err != nil {
|
||||
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
|
||||
}
|
||||
isDHStepped = true
|
||||
}
|
||||
|
||||
// After all, update the current chain.
|
||||
|
@ -148,16 +149,12 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||
}
|
||||
|
||||
// Apply changes.
|
||||
if err := s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Increment the number of keys
|
||||
sc.KeysCount++
|
||||
|
||||
if isDHStepped {
|
||||
err = s.deleteSkippedKeys(s.DHr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Apply changes.
|
||||
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store state
|
||||
|
|
|
@ -103,12 +103,9 @@ func (s *sessionHE) RatchetDecrypt(m MessageHE, ad []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||
}
|
||||
|
||||
if err = s.applyChanges(sc, append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||
if err = s.applyChanges(sc, []byte("FIXME"), append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply changes: %s", err)
|
||||
}
|
||||
if step {
|
||||
_ = s.deleteSkippedKeys(s.HKr)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
|
|
@ -134,6 +134,8 @@ func TestSessionHE_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
m1 = alice.RatchetEncrypt([]byte("bob"), nil)
|
||||
m2 = alice.RatchetEncrypt([]byte("how are you?"), nil)
|
||||
m3 = alice.RatchetEncrypt([]byte("still do cryptography?"), nil)
|
||||
m4 = alice.RatchetEncrypt([]byte("what up bob?"), nil)
|
||||
m5 = alice.RatchetEncrypt([]byte("bob?"), nil)
|
||||
)
|
||||
|
||||
// Act and assert.
|
||||
|
@ -154,7 +156,7 @@ func TestSessionHE_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, bobSkippedCount)
|
||||
|
||||
_, err = bob.RatchetDecrypt(m3, nil) // Error: too many to skip.
|
||||
_, err = bob.RatchetDecrypt(m5, nil) // Error: too many to skip.
|
||||
require.NotNil(t, err)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m2, nil) // Decrypted.
|
||||
|
@ -173,27 +175,78 @@ func TestSessionHE_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
d, err = bob.RatchetDecrypt(m0, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("hi"), d)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m4, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("what up bob?"), d)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m5, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("bob?"), d)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionHE_SkippedKeysDeletion(t *testing.T) {
|
||||
func TestSessionHE_OldKeysDeletion(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bob, _ = NewHE(sk, sharedHka, sharedNhkb, bobPair, WithMaxKeep(2))
|
||||
alice, _ = NewHEWithRemoteKey(sk, sharedHka, sharedNhkb, bobPair.PublicKey(), WithMaxKeep(2))
|
||||
h = SessionTestHelperHE{t, alice, bob}
|
||||
)
|
||||
|
||||
// Act.
|
||||
m0 := alice.RatchetEncrypt([]byte("Hi"), nil)
|
||||
|
||||
h.AliceToBob("Bob!", nil) // Bob ratchet step 1.
|
||||
h.BobToAlice("Alice?", nil) // Alice ratchet step 1.
|
||||
h.AliceToBob("How are you?", nil) // Bob ratchet step 2.
|
||||
m0 := alice.RatchetEncrypt([]byte("Hi 1"), nil)
|
||||
m1 := alice.RatchetEncrypt([]byte("Hi 2"), nil)
|
||||
m2 := alice.RatchetEncrypt([]byte("Hi 3"), nil)
|
||||
m3 := alice.RatchetEncrypt([]byte("Hi 4"), nil)
|
||||
|
||||
// Assert.
|
||||
_, err := bob.RatchetDecrypt(m0, nil)
|
||||
|
||||
// This one should be in the db
|
||||
_, err := bob.RatchetDecrypt(m1, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
// This one should be in the db
|
||||
_, err = bob.RatchetDecrypt(m3, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
// This key should be discarded
|
||||
_, err = bob.RatchetDecrypt(m0, nil)
|
||||
require.NotNil(t, err)
|
||||
|
||||
// This one should be in the db
|
||||
_, err = bob.RatchetDecrypt(m2, nil)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestSessionHE_ExtraKeysDeletion(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bob, _ = NewHE(sk, sharedHka, sharedNhkb, bobPair, WithMaxMessageKeysPerSession(2))
|
||||
alice, _ = NewHEWithRemoteKey(sk, sharedHka, sharedNhkb, bobPair.PublicKey(), WithMaxMessageKeysPerSession(2))
|
||||
)
|
||||
|
||||
// Act.
|
||||
m0 := alice.RatchetEncrypt([]byte("Hi 1"), nil)
|
||||
m1 := alice.RatchetEncrypt([]byte("Hi 2"), nil)
|
||||
m2 := alice.RatchetEncrypt([]byte("Hi 3"), nil)
|
||||
m3 := alice.RatchetEncrypt([]byte("Hi 4"), nil)
|
||||
|
||||
// Assert.
|
||||
_, err := bob.RatchetDecrypt(m3, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
// This key should be discarded
|
||||
_, err = bob.RatchetDecrypt(m0, nil)
|
||||
require.NotNil(t, err)
|
||||
|
||||
// This one should be in the db
|
||||
_, err = bob.RatchetDecrypt(m1, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
// This one should be in the db
|
||||
_, err = bob.RatchetDecrypt(m2, nil)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
type SessionTestHelperHE struct {
|
||||
|
|
|
@ -79,8 +79,8 @@ func TestSession_RatchetEncrypt_Basic(t *testing.T) {
|
|||
func TestSession_RatchetDecrypt_CommunicationFailedWithNoPublicKey(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bob, _ = New([]byte("id"), sk, bobPair, nil)
|
||||
alice, _ = New([]byte("id"), sk, alicePair, nil)
|
||||
bob, _ = New([]byte("bob"), sk, bobPair, nil)
|
||||
alice, _ = New([]byte("alice"), sk, alicePair, nil)
|
||||
)
|
||||
|
||||
// Act.
|
||||
|
@ -96,8 +96,8 @@ func TestSession_RatchetDecrypt_CommunicationFailedWithNoPublicKey(t *testing.T)
|
|||
func TestSession_RatchetDecrypt_CommunicationAliceSends(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bob, _ = New([]byte("id"), sk, bobPair, nil)
|
||||
alice, _ = NewWithRemoteKey([]byte("id"), sk, bobPair.PublicKey(), nil)
|
||||
bob, _ = New([]byte("bob"), sk, bobPair, nil)
|
||||
alice, _ = NewWithRemoteKey([]byte("alice"), sk, bobPair.PublicKey(), nil)
|
||||
)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -111,8 +111,8 @@ func TestSession_RatchetDecrypt_CommunicationAliceSends(t *testing.T) {
|
|||
|
||||
func TestSession_RatchetDecrypt_CommunicationBobSends(t *testing.T) {
|
||||
var (
|
||||
bob, _ = New([]byte("id"), sk, bobPair, nil)
|
||||
alice, _ = NewWithRemoteKey([]byte("id"), sk, bobPair.PublicKey(), nil)
|
||||
bob, _ = New([]byte("bob"), sk, bobPair, nil)
|
||||
alice, _ = NewWithRemoteKey([]byte("alice"), sk, bobPair.PublicKey(), nil)
|
||||
)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -127,8 +127,8 @@ func TestSession_RatchetDecrypt_CommunicationBobSends(t *testing.T) {
|
|||
func TestSession_RatchetDecrypt_CommunicationPingPong(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bob, _ = New([]byte("id"), sk, bobPair, nil)
|
||||
alice, _ = NewWithRemoteKey([]byte("id"), sk, bobPair.PublicKey(), nil)
|
||||
bob, _ = New([]byte("bob"), sk, bobPair, nil)
|
||||
alice, _ = NewWithRemoteKey([]byte("alice"), sk, bobPair.PublicKey(), nil)
|
||||
)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -145,10 +145,10 @@ func TestSession_RatchetDecrypt_CommunicationPingPong(t *testing.T) {
|
|||
func TestSession_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bobI, _ = New([]byte("id"), sk, bobPair, nil, WithMaxSkip(1))
|
||||
bobI, _ = New([]byte("bob"), sk, bobPair, nil, WithMaxSkip(1))
|
||||
bob = bobI.(*sessionState)
|
||||
|
||||
aliceI, _ = NewWithRemoteKey([]byte("id"), sk, bob.DHs.PublicKey(), nil, WithMaxSkip(1))
|
||||
aliceI, _ = NewWithRemoteKey([]byte("alice"), sk, bob.DHs.PublicKey(), nil, WithMaxSkip(1))
|
||||
alice = aliceI.(*sessionState)
|
||||
)
|
||||
|
||||
|
@ -166,6 +166,12 @@ func TestSession_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
m3, err := alice.RatchetEncrypt([]byte("still do cryptography?"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
m4, err := alice.RatchetEncrypt([]byte("you there?"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
m5, err := alice.RatchetEncrypt([]byte("bob? bob? BOB? BOB?"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Act and assert.
|
||||
m1.Ciphertext[len(m1.Ciphertext)-1] ^= 10
|
||||
_, err = bob.RatchetDecrypt(m1, nil) // Error: invalid signature.
|
||||
|
@ -184,7 +190,7 @@ func TestSession_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, bobSkippedCount)
|
||||
|
||||
_, err = bob.RatchetDecrypt(m3, nil) // Error: too many to skip.
|
||||
_, err = bob.RatchetDecrypt(m5, nil) // Too many messages
|
||||
require.NotNil(t, err)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m2, nil) // Decrypted.
|
||||
|
@ -203,14 +209,22 @@ func TestSession_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
d, err = bob.RatchetDecrypt(m0, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("hi"), d)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m4, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("you there?"), d)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m5, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("bob? bob? BOB? BOB?"), d)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSession_SkippedKeysDeletion(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
bob, _ = New([]byte("id"), sk, bobPair, nil, WithMaxKeep(2))
|
||||
alice, _ = NewWithRemoteKey([]byte("id"), sk, bobPair.PublicKey(), nil, WithMaxKeep(2))
|
||||
bob, _ = New([]byte("bob"), sk, bobPair, nil, WithMaxKeep(2))
|
||||
alice, _ = NewWithRemoteKey([]byte("alice"), sk, bobPair.PublicKey(), nil, WithMaxKeep(2))
|
||||
h = SessionTestHelper{t, alice, bob}
|
||||
)
|
||||
|
||||
|
|
62
state.go
62
state.go
|
@ -42,14 +42,18 @@ type State struct {
|
|||
// Sending header key and next header key. Only used for header encryption.
|
||||
HKs, NHKs Key
|
||||
|
||||
// Number of ratchet steps after which all skipped message keys for that key will be deleted.
|
||||
// How long we keep messages keys, counted in number of messages received,
|
||||
// for example if MaxKeep is 5 we only keep the last 5 messages keys, deleting everything n - 5.
|
||||
MaxKeep uint
|
||||
|
||||
// Max number of message keys per session, older keys will be deleted in FIFO fashion
|
||||
MaxMessageKeysPerSession int
|
||||
|
||||
// The number of the current ratchet step.
|
||||
Step uint
|
||||
|
||||
// Which key for the receiving chain was used at the specified step.
|
||||
DeleteKeys map[uint]Key
|
||||
// KeysCount the number of keys generated for decrypting
|
||||
KeysCount uint
|
||||
}
|
||||
|
||||
func DefaultState(sharedKey Key) State {
|
||||
|
@ -61,12 +65,13 @@ func DefaultState(sharedKey Key) State {
|
|||
RootCh: kdfRootChain{CK: sharedKey, Crypto: c},
|
||||
// Populate CKs and CKr with sharedKey so that both parties could send and receive
|
||||
// messages from the very beginning.
|
||||
SendCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||
MkSkipped: &KeysStorageInMemory{},
|
||||
MaxSkip: 1000,
|
||||
MaxKeep: 100,
|
||||
DeleteKeys: make(map[uint]Key),
|
||||
SendCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||
MkSkipped: &KeysStorageInMemory{},
|
||||
MaxSkip: 1000,
|
||||
MaxMessageKeysPerSession: 2000,
|
||||
MaxKeep: 2000,
|
||||
KeysCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,6 +117,7 @@ type skippedKey struct {
|
|||
key Key
|
||||
nr uint
|
||||
mk Key
|
||||
seq uint
|
||||
}
|
||||
|
||||
// skipMessageKeys skips message keys in the current receiving chain.
|
||||
|
@ -119,14 +125,11 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
|
|||
if until < uint(s.RecvCh.N) {
|
||||
return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
|
||||
}
|
||||
nSkipped, err := s.MkSkipped.Count(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if until-uint(s.RecvCh.N)+nSkipped > s.MaxSkip {
|
||||
if uint(s.RecvCh.N)+s.MaxSkip < until {
|
||||
return nil, fmt.Errorf("too many messages")
|
||||
}
|
||||
|
||||
skipped := []skippedKey{}
|
||||
for uint(s.RecvCh.N) < until {
|
||||
mk := s.RecvCh.step()
|
||||
|
@ -134,32 +137,31 @@ func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
|
|||
key: key,
|
||||
nr: uint(s.RecvCh.N - 1),
|
||||
mk: mk,
|
||||
seq: s.KeysCount,
|
||||
})
|
||||
// Increment key count
|
||||
s.KeysCount++
|
||||
|
||||
}
|
||||
return skipped, nil
|
||||
}
|
||||
|
||||
func (s *State) applyChanges(sc State, skipped []skippedKey) error {
|
||||
func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error {
|
||||
*s = sc
|
||||
for _, skipped := range skipped {
|
||||
if err := s.MkSkipped.Put(skipped.key, skipped.nr, skipped.mk); err != nil {
|
||||
if err := s.MkSkipped.Put(sessionID, skipped.key, skipped.nr, skipped.mk, skipped.seq); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.KeysCount >= s.MaxKeep {
|
||||
if err := s.MkSkipped.DeleteOldMks(sessionID, s.KeysCount-s.MaxKeep); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) deleteSkippedKeys(key Key) error {
|
||||
|
||||
s.DeleteKeys[s.Step] = key
|
||||
s.Step++
|
||||
if hk, ok := s.DeleteKeys[s.Step-s.MaxKeep]; ok {
|
||||
if err := s.MkSkipped.DeletePk(hk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(s.DeleteKeys, s.Step-s.MaxKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -45,7 +45,6 @@ func TestNewState_Basic(t *testing.T) {
|
|||
|
||||
require.NotNil(t, s.MkSkipped)
|
||||
require.NotNil(t, s.Crypto)
|
||||
require.NotNil(t, s.DeleteKeys)
|
||||
}
|
||||
|
||||
func TestNewState_BadSharedKey(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue