Changes to state maps are now only applied at the end of RatchetDecrypt
This commit is contained in:
parent
fffab73a03
commit
6cbd7777ee
95
state.go
95
state.go
|
@ -141,45 +141,64 @@ func (s *state) RatchetEncrypt(plaintext []byte, ad AssociatedData) Message {
|
|||
|
||||
// RatchetDecrypt is called to decrypt messages.
|
||||
func (s *state) RatchetDecrypt(m Message, ad AssociatedData) ([]byte, error) {
|
||||
// All changes must be applied on a different state object, so that this state won't be modified nor left in a dirty state.
|
||||
var sc state = *s
|
||||
|
||||
// DEBUG
|
||||
//fmt.Printf("%+v\n\n", sc)
|
||||
//defer fmt.Printf("%+v\n\n", s)
|
||||
|
||||
// Is the messages one of the skipped?
|
||||
plaintext, err := sc.trySkippedMessageKeys(m, ad)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't try skipped message: %s", err)
|
||||
}
|
||||
if plaintext != nil {
|
||||
// Is the message one of the skipped?
|
||||
if mk, ok := s.MkSkipped.Get(m.Header.DH, m.Header.N); ok {
|
||||
plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, m.Header.EncodeWithAD(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't decrypt skipped message: %s", err)
|
||||
}
|
||||
s.MkSkipped.DeleteMk(m.Header.DH, m.Header.N)
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
var (
|
||||
// All changes must be applied on a different state object, so that this state won't be modified nor left in a dirty state.
|
||||
sc state = *s
|
||||
|
||||
skippedKeys1 []skippedKey
|
||||
skippedKeys2 []skippedKey
|
||||
err error
|
||||
)
|
||||
|
||||
// Is there a new ratchet key?
|
||||
isDHStepped := false
|
||||
if m.Header.DH != sc.DHr {
|
||||
if err := sc.skipMessageKeys(m.Header.PN); err != nil {
|
||||
return nil, fmt.Errorf("failed to skip previous chain message keys: %s", err)
|
||||
if skippedKeys1, err = sc.skipMessageKeys(m.Header.PN); err != nil {
|
||||
return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
|
||||
}
|
||||
if err := sc.dhRatchet(m.Header); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform ratchet step: %s", err)
|
||||
if err = sc.dhRatchet(m.Header); err != nil {
|
||||
return nil, fmt.Errorf("can't perform ratchet step: %s", err)
|
||||
}
|
||||
isDHStepped = true
|
||||
}
|
||||
|
||||
// After all, apply changes on the current chain.
|
||||
if err := sc.skipMessageKeys(m.Header.N); err != nil {
|
||||
return nil, fmt.Errorf("failed to skip current chain message keys: %s", err)
|
||||
// After all, update the current chain.
|
||||
if skippedKeys2, err = sc.skipMessageKeys(m.Header.N); err != nil {
|
||||
return nil, fmt.Errorf("can't skip current chain message keys: %s", err)
|
||||
}
|
||||
var mk Key
|
||||
sc.CKr, mk = sc.Crypto.KdfCK(sc.CKr)
|
||||
sc.Nr++
|
||||
plaintext, err = sc.Crypto.Decrypt(mk, m.Ciphertext, m.Header.EncodeWithAD(ad))
|
||||
plaintext, err := sc.Crypto.Decrypt(mk, m.Ciphertext, m.Header.EncodeWithAD(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt: %s", err)
|
||||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||
}
|
||||
|
||||
// Apply changes.
|
||||
*s = sc
|
||||
if isDHStepped {
|
||||
s.PubKeys[s.Step] = s.DHr
|
||||
s.Step++
|
||||
if pubKey, ok := s.PubKeys[s.Step-s.MaxKeep]; ok {
|
||||
s.MkSkipped.DeletePk(pubKey)
|
||||
}
|
||||
}
|
||||
for _, skipped := range skippedKeys1 {
|
||||
s.MkSkipped.Put(skipped.dhr, skipped.nr, skipped.mk)
|
||||
}
|
||||
for _, skipped := range skippedKeys2 {
|
||||
s.MkSkipped.Put(skipped.dhr, skipped.nr, skipped.mk)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
@ -188,33 +207,26 @@ func (s *state) PublicKey() Key {
|
|||
return s.DHs.PublicKey()
|
||||
}
|
||||
|
||||
// trySkippedMessageKeys tries to decrypt the message with a skipped message key.
|
||||
func (s *state) trySkippedMessageKeys(m Message, ad AssociatedData) ([]byte, error) {
|
||||
if mk, ok := s.MkSkipped.Get(m.Header.DH, m.Header.N); ok {
|
||||
plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, m.Header.EncodeWithAD(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't decrypt message: %s", err)
|
||||
}
|
||||
s.MkSkipped.DeleteMk(m.Header.DH, m.Header.N)
|
||||
return plaintext, nil
|
||||
}
|
||||
return nil, nil
|
||||
type skippedKey struct {
|
||||
dhr Key
|
||||
nr uint
|
||||
mk Key
|
||||
}
|
||||
|
||||
// skipMessageKeys skips message keys in the current receiving chain.
|
||||
func (s *state) skipMessageKeys(until uint) error {
|
||||
func (s *state) skipMessageKeys(until uint) ([]skippedKey, error) {
|
||||
nSkipped := s.MkSkipped.Count(s.DHr)
|
||||
if until-s.Nr+nSkipped > s.MaxSkip {
|
||||
return fmt.Errorf("too many messages")
|
||||
return nil, fmt.Errorf("too many messages")
|
||||
}
|
||||
skipped := []skippedKey{}
|
||||
for s.Nr < until {
|
||||
var mk Key
|
||||
s.CKr, mk = s.Crypto.KdfCK(s.CKr)
|
||||
// FIXME: Changes to MkSkipped must not affect state.
|
||||
s.MkSkipped.Put(s.DHr, s.Nr, mk)
|
||||
skipped = append(skipped, skippedKey{s.DHr, s.Nr, mk})
|
||||
s.Nr++
|
||||
}
|
||||
return nil
|
||||
return skipped, nil
|
||||
}
|
||||
|
||||
// dhRatchet performs a single ratchet step.
|
||||
|
@ -232,12 +244,5 @@ func (s *state) dhRatchet(mh MessageHeader) error {
|
|||
}
|
||||
s.RK, s.CKs = s.Crypto.KdfRK(s.RK, s.Crypto.DH(s.DHs, s.DHr))
|
||||
|
||||
// FIXME: Changes to PubKeys must not affect the state object.
|
||||
s.PubKeys[s.Step] = s.DHr
|
||||
s.Step++
|
||||
if pubKey, ok := s.PubKeys[s.Step-s.MaxKeep]; ok {
|
||||
s.MkSkipped.DeletePk(pubKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue