MaxSkip is now handled correctly: no more than MaxSkip messages is allowed in a single chain
This commit is contained in:
parent
c284ef8697
commit
25eacbff6a
|
@ -30,3 +30,6 @@ type DHPair interface {
|
|||
PrivateKey() [32]byte
|
||||
PublicKey() [32]byte
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// type Key [32]byte
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
package doubleratchet
|
||||
|
||||
// KeysStorage is an interface of an abstract in-memory or persistent keys storage.
|
||||
type KeysStorage interface {
|
||||
// Get returns a message key by the given public key and message number.
|
||||
Get(pubKey [32]byte, msgNum uint) (mk [32]byte, ok bool)
|
||||
|
||||
// Put saves the given mk under the specified pubKey and msgNum.
|
||||
Put(pubKey [32]byte, msgNum uint, mk [32]byte)
|
||||
|
||||
// Delete ensures there's no message key under the specified pubKey and msgNum.
|
||||
Delete(pubKey [32]byte, msgNum uint)
|
||||
|
||||
// Count returns number of message keys stored under pubKey.
|
||||
Count(pubKey [32]byte) uint
|
||||
}
|
||||
|
||||
// KeysStorageInMemory is an in-memory message keys storage.
|
||||
type KeysStorageInMemory struct {
|
||||
keys map[[32]byte]map[uint][32]byte
|
||||
}
|
||||
|
||||
func (s *KeysStorageInMemory) Get(pubKey [32]byte, msgNum uint) ([32]byte, bool) {
|
||||
if s.keys == nil {
|
||||
s.keys = make(map[[32]byte]map[uint][32]byte)
|
||||
}
|
||||
msgs, ok := s.keys[pubKey]
|
||||
if !ok {
|
||||
return [32]byte{}, false
|
||||
}
|
||||
mk, ok := msgs[msgNum]
|
||||
if !ok {
|
||||
return [32]byte{}, false
|
||||
}
|
||||
return mk, true
|
||||
}
|
||||
|
||||
func (s *KeysStorageInMemory) Put(pubKey [32]byte, msgNum uint, mk [32]byte) {
|
||||
if s.keys == nil {
|
||||
s.keys = make(map[[32]byte]map[uint][32]byte)
|
||||
}
|
||||
if _, ok := s.keys[pubKey]; !ok {
|
||||
s.keys[pubKey] = make(map[uint][32]byte)
|
||||
}
|
||||
s.keys[pubKey][msgNum] = mk
|
||||
}
|
||||
|
||||
func (s *KeysStorageInMemory) Delete(pubKey [32]byte, msgNum uint) {
|
||||
if s.keys == nil {
|
||||
return
|
||||
}
|
||||
if _, ok := s.keys[pubKey]; !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := s.keys[pubKey][msgNum]; !ok {
|
||||
return
|
||||
}
|
||||
delete(s.keys[pubKey], msgNum)
|
||||
if len(s.keys[pubKey]) == 0 {
|
||||
delete(s.keys, pubKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeysStorageInMemory) Count(pubKey [32]byte) uint {
|
||||
if s.keys == nil {
|
||||
return 0
|
||||
}
|
||||
return uint(len(s.keys[pubKey]))
|
||||
}
|
28
state.go
28
state.go
|
@ -4,8 +4,6 @@ package doubleratchet
|
|||
// a number of events (messages received, DH ratchet steps, etc.). It's better to use some
|
||||
// deterministic measure.
|
||||
|
||||
// FIXME: Correct MaxSkip handling for message numbers like: 1, 3, 5
|
||||
|
||||
// TODO: During each DH ratchet step a new ratchet key pair and sending chain are generated.
|
||||
// As the sending chain is not needed right away, these steps could be deferred until the party
|
||||
// is about to send a new message.
|
||||
|
@ -13,7 +11,6 @@ package doubleratchet
|
|||
// TODO: Think if to truncate an authentication tag to 128 bits.
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
@ -50,7 +47,7 @@ type state struct {
|
|||
PN uint
|
||||
|
||||
// Dictionary of skipped-over message keys, indexed by ratchet public key and message number.
|
||||
MkSkipped map[string][32]byte
|
||||
MkSkipped KeysStorage
|
||||
|
||||
// MaxSkip should be set high enough to tolerate routine lost or delayed messages,
|
||||
// but low enough that a malicious sender can't trigger excessive recipient computation.
|
||||
|
@ -70,7 +67,7 @@ func New(sharedKey [32]byte, opts ...Option) (State, error) {
|
|||
RK: sharedKey,
|
||||
CKs: sharedKey, // Populate CKs and CKr with sharedKey as per specification so that both
|
||||
CKr: sharedKey, // parties could both send and receive messages from the very beginning.
|
||||
MkSkipped: make(map[string][32]byte),
|
||||
MkSkipped: &KeysStorageInMemory{},
|
||||
MaxSkip: 1000,
|
||||
Crypto: DefaultCrypto{},
|
||||
}
|
||||
|
@ -113,6 +110,9 @@ func WithMaxSkip(n int) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: WithKeysStorage.
|
||||
// TODO: WithCrypto.
|
||||
|
||||
// RatchetEncrypt performs a symmetric-key ratchet step, then encrypts the message with
|
||||
// the resulting message key.
|
||||
func (s *state) RatchetEncrypt(plaintext []byte, ad AssociatedData) Message {
|
||||
|
@ -174,33 +174,27 @@ func (s *state) RatchetDecrypt(m Message, ad AssociatedData) ([]byte, error) {
|
|||
|
||||
// trySkippedMessageKeys tries to decrypt the message with a skipped message key.
|
||||
func (s *state) trySkippedMessageKeys(m Message, ad AssociatedData) ([]byte, error) {
|
||||
k := s.skippedKey(m.Header.DH[:], m.Header.N)
|
||||
if mk, ok := s.MkSkipped[k]; ok {
|
||||
if mk, ok := s.MkSkipped.Get(m.Header.DH, m.Header.N); ok {
|
||||
plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, m.Header.EncodeWithAD(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't decrypt message: %s", err)
|
||||
}
|
||||
delete(s.MkSkipped, k)
|
||||
s.MkSkipped.Delete(m.Header.DH, m.Header.N)
|
||||
return plaintext, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// skippedKey forms a key for a skipped message.
|
||||
func (s *state) skippedKey(dh []byte, n uint) string {
|
||||
return fmt.Sprintf("%s%d", hex.EncodeToString(dh), n)
|
||||
}
|
||||
|
||||
// skipMessageKeys skips message keys in the current receiving chain.
|
||||
func (s *state) skipMessageKeys(until uint) error {
|
||||
// until exceeds the number of messages in the receiving chain for no more than s.MaxSkip
|
||||
if s.Nr+s.MaxSkip < until {
|
||||
return fmt.Errorf("too many messages: %d", until-s.Nr)
|
||||
nSkipped := s.MkSkipped.Count(s.DHr)
|
||||
if until-s.Nr+nSkipped > s.MaxSkip {
|
||||
return fmt.Errorf("too many messages")
|
||||
}
|
||||
for s.Nr < until {
|
||||
var mk [32]byte
|
||||
s.CKr, mk = s.Crypto.KdfCK(s.CKr)
|
||||
s.MkSkipped[s.skippedKey(s.DHr[:], s.Nr)] = mk
|
||||
s.MkSkipped.Put(s.DHr, s.Nr, mk)
|
||||
s.Nr++
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -198,29 +198,29 @@ func TestState_RatchetDecrypt_CommunicationSkippedMessages(t *testing.T) {
|
|||
t.Run("skipped messages from alice", func(t *testing.T) {
|
||||
// Arrange.
|
||||
var (
|
||||
m1 = alice.RatchetEncrypt([]byte("hi"), nil)
|
||||
m2 = alice.RatchetEncrypt([]byte("bob"), nil)
|
||||
m3 = alice.RatchetEncrypt([]byte("how are you?"), nil)
|
||||
m4 = alice.RatchetEncrypt([]byte("still do cryptography?"), nil)
|
||||
m0 = alice.RatchetEncrypt([]byte("hi"), nil)
|
||||
m1 = alice.RatchetEncrypt([]byte("bob"), nil)
|
||||
m2 = alice.RatchetEncrypt([]byte("how are you?"), nil)
|
||||
m3 = alice.RatchetEncrypt([]byte("still do cryptography?"), nil)
|
||||
)
|
||||
|
||||
// Act and assert.
|
||||
d, err := bob.RatchetDecrypt(m2, nil) // Decrypted and skipped.
|
||||
d, err := bob.RatchetDecrypt(m1, nil) // Decrypted and skipped.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("bob"), d)
|
||||
|
||||
_, err = bob.RatchetDecrypt(m4, nil) // Error: too many to skip.
|
||||
require.Nil(t, err)
|
||||
_, err = bob.RatchetDecrypt(m3, nil) // Error: too many to skip.
|
||||
require.NotNil(t, err)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m3, nil) // Decrypted.
|
||||
d, err = bob.RatchetDecrypt(m2, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("how are you?"), d)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m4, nil) // Decrypted.
|
||||
d, err = bob.RatchetDecrypt(m3, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("still do cryptography?"), d)
|
||||
|
||||
d, err = bob.RatchetDecrypt(m1, nil) // Decrypted.
|
||||
d, err = bob.RatchetDecrypt(m0, nil) // Decrypted.
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, []byte("hi"), d)
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue