MaxSkip is now handled correctly: no more than MaxSkip messages is allowed in a single chain

This commit is contained in:
Ivan Tomilov 2017-06-13 23:46:11 +07:00
parent c284ef8697
commit 25eacbff6a
4 changed files with 93 additions and 27 deletions

View File

@ -30,3 +30,6 @@ type DHPair interface {
PrivateKey() [32]byte
PublicKey() [32]byte
}
// TODO:
// type Key [32]byte

69
keys_storage.go Normal file
View File

@ -0,0 +1,69 @@
package doubleratchet
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
type KeysStorage interface {
// Get returns a message key by the given public key and message number.
Get(pubKey [32]byte, msgNum uint) (mk [32]byte, ok bool)
// Put saves the given mk under the specified pubKey and msgNum.
Put(pubKey [32]byte, msgNum uint, mk [32]byte)
// Delete ensures there's no message key under the specified pubKey and msgNum.
Delete(pubKey [32]byte, msgNum uint)
// Count returns number of message keys stored under pubKey.
Count(pubKey [32]byte) uint
}
// KeysStorageInMemory is an in-memory message keys storage.
type KeysStorageInMemory struct {
keys map[[32]byte]map[uint][32]byte
}
func (s *KeysStorageInMemory) Get(pubKey [32]byte, msgNum uint) ([32]byte, bool) {
if s.keys == nil {
s.keys = make(map[[32]byte]map[uint][32]byte)
}
msgs, ok := s.keys[pubKey]
if !ok {
return [32]byte{}, false
}
mk, ok := msgs[msgNum]
if !ok {
return [32]byte{}, false
}
return mk, true
}
func (s *KeysStorageInMemory) Put(pubKey [32]byte, msgNum uint, mk [32]byte) {
if s.keys == nil {
s.keys = make(map[[32]byte]map[uint][32]byte)
}
if _, ok := s.keys[pubKey]; !ok {
s.keys[pubKey] = make(map[uint][32]byte)
}
s.keys[pubKey][msgNum] = mk
}
func (s *KeysStorageInMemory) Delete(pubKey [32]byte, msgNum uint) {
if s.keys == nil {
return
}
if _, ok := s.keys[pubKey]; !ok {
return
}
if _, ok := s.keys[pubKey][msgNum]; !ok {
return
}
delete(s.keys[pubKey], msgNum)
if len(s.keys[pubKey]) == 0 {
delete(s.keys, pubKey)
}
}
func (s *KeysStorageInMemory) Count(pubKey [32]byte) uint {
if s.keys == nil {
return 0
}
return uint(len(s.keys[pubKey]))
}

View File

@ -4,8 +4,6 @@ package doubleratchet
// a number of events (messages received, DH ratchet steps, etc.). It's better to use some
// deterministic measure.
// FIXME: Correct MaxSkip handling for message numbers like: 1, 3, 5
// TODO: During each DH ratchet step a new ratchet key pair and sending chain are generated.
// As the sending chain is not needed right away, these steps could be deferred until the party
// is about to send a new message.
@ -13,7 +11,6 @@ package doubleratchet
// TODO: Think if to truncate an authentication tag to 128 bits.
import (
"encoding/hex"
"fmt"
)
@ -50,7 +47,7 @@ type state struct {
PN uint
// Dictionary of skipped-over message keys, indexed by ratchet public key and message number.
MkSkipped map[string][32]byte
MkSkipped KeysStorage
// MaxSkip should be set high enough to tolerate routine lost or delayed messages,
// but low enough that a malicious sender can't trigger excessive recipient computation.
@ -70,7 +67,7 @@ func New(sharedKey [32]byte, opts ...Option) (State, error) {
RK: sharedKey,
CKs: sharedKey, // Populate CKs and CKr with sharedKey as per specification so that both
CKr: sharedKey, // parties could both send and receive messages from the very beginning.
MkSkipped: make(map[string][32]byte),
MkSkipped: &KeysStorageInMemory{},
MaxSkip: 1000,
Crypto: DefaultCrypto{},
}
@ -113,6 +110,9 @@ func WithMaxSkip(n int) Option {
}
}
// TODO: WithKeysStorage.
// TODO: WithCrypto.
// RatchetEncrypt performs a symmetric-key ratchet step, then encrypts the message with
// the resulting message key.
func (s *state) RatchetEncrypt(plaintext []byte, ad AssociatedData) Message {
@ -174,33 +174,27 @@ func (s *state) RatchetDecrypt(m Message, ad AssociatedData) ([]byte, error) {
// trySkippedMessageKeys tries to decrypt the message with a skipped message key.
func (s *state) trySkippedMessageKeys(m Message, ad AssociatedData) ([]byte, error) {
k := s.skippedKey(m.Header.DH[:], m.Header.N)
if mk, ok := s.MkSkipped[k]; ok {
if mk, ok := s.MkSkipped.Get(m.Header.DH, m.Header.N); ok {
plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, m.Header.EncodeWithAD(ad))
if err != nil {
return nil, fmt.Errorf("can't decrypt message: %s", err)
}
delete(s.MkSkipped, k)
s.MkSkipped.Delete(m.Header.DH, m.Header.N)
return plaintext, nil
}
return nil, nil
}
// skippedKey forms a key for a skipped message.
func (s *state) skippedKey(dh []byte, n uint) string {
return fmt.Sprintf("%s%d", hex.EncodeToString(dh), n)
}
// skipMessageKeys skips message keys in the current receiving chain.
func (s *state) skipMessageKeys(until uint) error {
// until exceeds the number of messages in the receiving chain for no more than s.MaxSkip
if s.Nr+s.MaxSkip < until {
return fmt.Errorf("too many messages: %d", until-s.Nr)
nSkipped := s.MkSkipped.Count(s.DHr)
if until-s.Nr+nSkipped > s.MaxSkip {
return fmt.Errorf("too many messages")
}
for s.Nr < until {
var mk [32]byte
s.CKr, mk = s.Crypto.KdfCK(s.CKr)
s.MkSkipped[s.skippedKey(s.DHr[:], s.Nr)] = mk
s.MkSkipped.Put(s.DHr, s.Nr, mk)
s.Nr++
}
return nil

View File

@ -198,29 +198,29 @@ func TestState_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
t.Run("skipped messages from alice", func(t *testing.T) {
// Arrange.
var (
m1 = alice.RatchetEncrypt([]byte("hi"), nil)
m2 = alice.RatchetEncrypt([]byte("bob"), nil)
m3 = alice.RatchetEncrypt([]byte("how are you?"), nil)
m4 = alice.RatchetEncrypt([]byte("still do cryptography?"), nil)
m0 = alice.RatchetEncrypt([]byte("hi"), nil)
m1 = alice.RatchetEncrypt([]byte("bob"), nil)
m2 = alice.RatchetEncrypt([]byte("how are you?"), nil)
m3 = alice.RatchetEncrypt([]byte("still do cryptography?"), nil)
)
// Act and assert.
d, err := bob.RatchetDecrypt(m2, nil) // Decrypted and skipped.
d, err := bob.RatchetDecrypt(m1, nil) // Decrypted and skipped.
require.Nil(t, err)
require.Equal(t, []byte("bob"), d)
_, err = bob.RatchetDecrypt(m4, nil) // Error: too many to skip.
require.Nil(t, err)
_, err = bob.RatchetDecrypt(m3, nil) // Error: too many to skip.
require.NotNil(t, err)
d, err = bob.RatchetDecrypt(m3, nil) // Decrypted.
d, err = bob.RatchetDecrypt(m2, nil) // Decrypted.
require.Nil(t, err)
require.Equal(t, []byte("how are you?"), d)
d, err = bob.RatchetDecrypt(m4, nil) // Decrypted.
d, err = bob.RatchetDecrypt(m3, nil) // Decrypted.
require.Nil(t, err)
require.Equal(t, []byte("still do cryptography?"), d)
d, err = bob.RatchetDecrypt(m1, nil) // Decrypted.
d, err = bob.RatchetDecrypt(m0, nil) // Decrypted.
require.Nil(t, err)
require.Equal(t, []byte("hi"), d)
})