179 lines
4.5 KiB
Go
179 lines
4.5 KiB
Go
package doubleratchet
|
|
|
|
// TODO: During each DH ratchet step a new ratchet key pair and sending chain are generated.
|
|
// As the sending chain is not needed right away, these steps could be deferred until the party
|
|
// is about to send a new message.
|
|
|
|
import (
|
|
"fmt"
|
|
)
|
|
|
|
// The double ratchet state.
|
|
type State struct {
|
|
Crypto Crypto
|
|
|
|
// DH Ratchet public key (the remote key).
|
|
DHr Key
|
|
|
|
// DH Ratchet key pair (the self ratchet key).
|
|
DHs DHPair
|
|
|
|
// Symmetric ratchet root chain.
|
|
RootCh kdfRootChain
|
|
|
|
// Symmetric ratchet sending and receiving chains.
|
|
SendCh, RecvCh kdfChain
|
|
|
|
// Number of messages in previous sending chain.
|
|
PN uint32
|
|
|
|
// Dictionary of skipped-over message keys, indexed by ratchet public key or header key
|
|
// and message number.
|
|
MkSkipped KeysStorage
|
|
|
|
// The maximum number of message keys that can be skipped in a single chain.
|
|
// WithMaxSkip should be set high enough to tolerate routine lost or delayed messages,
|
|
// but low enough that a malicious sender can't trigger excessive recipient computation.
|
|
MaxSkip uint
|
|
|
|
// Receiving header key and next header key. Only used for header encryption.
|
|
HKr, NHKr Key
|
|
|
|
// Sending header key and next header key. Only used for header encryption.
|
|
HKs, NHKs Key
|
|
|
|
// How long we keep messages keys, counted in number of messages received,
|
|
// for example if MaxKeep is 5 we only keep the last 5 messages keys, deleting everything n - 5.
|
|
MaxKeep uint
|
|
|
|
// Max number of message keys per session, older keys will be deleted in FIFO fashion
|
|
MaxMessageKeysPerSession int
|
|
|
|
// The number of the current ratchet step.
|
|
Step uint
|
|
|
|
// KeysCount the number of keys generated for decrypting
|
|
KeysCount uint
|
|
}
|
|
|
|
func DefaultState(sharedKey Key) State {
|
|
c := DefaultCrypto{}
|
|
|
|
return State{
|
|
DHs: dhPair{},
|
|
Crypto: c,
|
|
RootCh: kdfRootChain{CK: sharedKey, Crypto: c},
|
|
// Populate CKs and CKr with sharedKey so that both parties could send and receive
|
|
// messages from the very beginning.
|
|
SendCh: kdfChain{CK: sharedKey, Crypto: c},
|
|
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
|
|
MkSkipped: &KeysStorageInMemory{},
|
|
MaxSkip: 1000,
|
|
MaxMessageKeysPerSession: 2000,
|
|
MaxKeep: 2000,
|
|
KeysCount: 0,
|
|
}
|
|
}
|
|
|
|
func (s *State) applyOptions(opts []option) error {
|
|
for i := range opts {
|
|
if err := opts[i](s); err != nil {
|
|
return fmt.Errorf("failed to apply option: %s", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func newState(sharedKey Key, opts ...option) (State, error) {
|
|
if sharedKey == nil {
|
|
return State{}, fmt.Errorf("sharedKey mustn't be empty")
|
|
}
|
|
|
|
s := DefaultState(sharedKey)
|
|
if err := s.applyOptions(opts); err != nil {
|
|
return State{}, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// dhRatchet performs a single ratchet step.
|
|
func (s *State) dhRatchet(m MessageHeader) error {
|
|
s.PN = s.SendCh.N
|
|
s.DHr = m.DH
|
|
s.HKs = s.NHKs
|
|
s.HKr = s.NHKr
|
|
|
|
recvSecret, err := s.Crypto.DH(s.DHs, s.DHr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate dh recieve ratchet secret: %s", err)
|
|
}
|
|
s.RecvCh, s.NHKr = s.RootCh.step(recvSecret)
|
|
|
|
s.DHs, err = s.Crypto.GenerateDH()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate dh pair: %s", err)
|
|
}
|
|
|
|
sendSecret, err := s.Crypto.DH(s.DHs, s.DHr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate dh send ratchet secret: %s", err)
|
|
}
|
|
s.SendCh, s.NHKs = s.RootCh.step(sendSecret)
|
|
|
|
return nil
|
|
}
|
|
|
|
type skippedKey struct {
|
|
key Key
|
|
nr uint
|
|
mk Key
|
|
seq uint
|
|
}
|
|
|
|
// skipMessageKeys skips message keys in the current receiving chain.
|
|
func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
|
|
if until < uint(s.RecvCh.N) {
|
|
return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
|
|
}
|
|
|
|
if uint(s.RecvCh.N)+s.MaxSkip < until {
|
|
return nil, fmt.Errorf("too many messages")
|
|
}
|
|
|
|
skipped := []skippedKey{}
|
|
for uint(s.RecvCh.N) < until {
|
|
mk := s.RecvCh.step()
|
|
skipped = append(skipped, skippedKey{
|
|
key: key,
|
|
nr: uint(s.RecvCh.N - 1),
|
|
mk: mk,
|
|
seq: s.KeysCount,
|
|
})
|
|
// Increment key count
|
|
s.KeysCount++
|
|
|
|
}
|
|
return skipped, nil
|
|
}
|
|
|
|
func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error {
|
|
*s = sc
|
|
for _, skipped := range skipped {
|
|
if err := s.MkSkipped.Put(sessionID, skipped.key, skipped.nr, skipped.mk, skipped.seq); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession); err != nil {
|
|
return err
|
|
}
|
|
if s.KeysCount >= s.MaxKeep {
|
|
if err := s.MkSkipped.DeleteOldMks(sessionID, s.KeysCount-s.MaxKeep); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|