diff --git a/chains.go b/chains.go index ab3beb0..b0c53c0 100644 --- a/chains.go +++ b/chains.go @@ -11,22 +11,7 @@ type KDFer interface { KdfCK(ck Key) (chainKey, msgKey Key) } -type rootChain struct { - Crypto KDFer - - // 32-byte chain key. - CK Key -} - -func (c *rootChain) Step(kdfInput Key) (ch chain, nhk Key) { - ch = chain{ - Crypto: c.Crypto, - } - c.CK, ch.CK, nhk = c.Crypto.KdfRK(c.CK, kdfInput) - return ch, nhk -} - -type chain struct { +type kdfChain struct { Crypto KDFer // 32-byte chain key. @@ -36,10 +21,26 @@ type chain struct { N uint32 } -// Step performs chain step and returns message key. -func (c *chain) Step() Key { +// step performs symmetric ratchet step and returns a new message key. +func (c *kdfChain) step() Key { var mk Key c.CK, mk = c.Crypto.KdfCK(c.CK) c.N++ return mk } + +type kdfRootChain struct { + Crypto KDFer + + // 32-byte kdfChain key. + CK Key +} + +// step performs symmetric ratchet step and returns a new chain and new header key. +func (c *kdfRootChain) step(kdfInput Key) (ch kdfChain, nhk Key) { + ch = kdfChain{ + Crypto: c.Crypto, + } + c.CK, ch.CK, nhk = c.Crypto.KdfRK(c.CK, kdfInput) + return ch, nhk +} diff --git a/chains_test.go b/chains_test.go new file mode 100644 index 0000000..90549bb --- /dev/null +++ b/chains_test.go @@ -0,0 +1,41 @@ +package doubleratchet + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +var chainKey = Key{0xeb, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40} + +func TestChain_Step(t *testing.T) { + // Arrange. + ch := kdfChain{ + Crypto: DefaultCrypto{}, + CK: chainKey, + } + + // Act. + mk := ch.step() + + // Assert. + require.EqualValues(t, 1, ch.N) + require.NotEqual(t, chainKey, ch.CK) + require.NotEqual(t, [32]byte{}, mk) +} + +func TestRootChain_Step(t *testing.T) { + // Arrange. + rch := kdfRootChain{ + Crypto: DefaultCrypto{}, + CK: chainKey, + } + + // Act. + ch, nhk := rch.step(Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c}) + + // Assert. + require.NotEmpty(t, ch.Crypto) + require.Empty(t, ch.N) + require.NotEqual(t, [32]byte{}, ch) + require.NotEqual(t, [32]byte{}, nhk) +} diff --git a/session.go b/session.go index 9b56436..3d0a073 100644 --- a/session.go +++ b/session.go @@ -35,7 +35,7 @@ func NewWithRK(sharedKey, remoteKey Key, opts ...option) (Session, error) { s := sI.(*session) s.DHr = remoteKey // FIXME: Where the header key goes? - s.SendCh, _ = s.RootCh.Step(s.Crypto.DH(s.DHs, s.DHr)) + s.SendCh, _ = s.RootCh.step(s.Crypto.DH(s.DHs, s.DHr)) return s, nil } @@ -49,7 +49,7 @@ func (s *session) RatchetEncrypt(plaintext, ad []byte) Message { N: s.SendCh.N, PN: s.PN, } - mk = s.SendCh.Step() + mk = s.SendCh.step() ) adBuf = append(adBuf, ad...) ct := s.Crypto.Encrypt(mk, plaintext, append(adBuf, h.Encode()...)) @@ -93,7 +93,7 @@ func (s *session) RatchetDecrypt(m Message, ad []byte) ([]byte, error) { if skippedKeys2, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.N)); err != nil { return nil, fmt.Errorf("can't skip current chain message keys: %s", err) } - mk := sc.RecvCh.Step() + mk := sc.RecvCh.step() plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, append(ad, m.Header.Encode()...)) if err != nil { return nil, fmt.Errorf("can't decrypt: %s", err) diff --git a/session_he.go b/session_he.go index 89bb2e8..0cdc85e 100644 --- a/session_he.go +++ b/session_he.go @@ -27,7 +27,7 @@ func (s *sessionHE) RatchetEncryptHE(plaintext, ad []byte) MessageHE { N: s.SendCh.N, PN: s.PN, } - mk = s.SendCh.Step() + mk = s.SendCh.step() hEnc = s.Crypto.Encrypt(s.HKs, h.Encode(), nil) ) return MessageHE{ @@ -87,7 +87,7 @@ func (s *sessionHE) RatchetDecryptHE(m MessageHE, ad []byte) ([]byte, error) { if skippedKeys2, err = sc.skipMessageKeys(s.HKr, uint(h.N)); err != nil { return nil, fmt.Errorf("can't skip current chain message keys: %s", err) } - mk := sc.RecvCh.Step() + mk := sc.RecvCh.step() plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, append(ad, m.Header...)) if err != nil { return nil, fmt.Errorf("can't decrypt: %s", err) diff --git a/state.go b/state.go index a3d05ad..7dfaf1f 100644 --- a/state.go +++ b/state.go @@ -19,10 +19,10 @@ type state struct { DHs DHPair // Symmetric ratchet root chain. - RootCh rootChain + RootCh kdfRootChain // Symmetric ratchet sending and receiving chains. - SendCh, RecvCh chain + SendCh, RecvCh kdfChain // Number of messages in previous sending chain. PN uint32 @@ -65,11 +65,11 @@ func newState(sharedKey Key, opts ...option) (state, error) { s := state{ Crypto: c, DHs: dhs, - RootCh: rootChain{CK: sharedKey, Crypto: c}, + RootCh: kdfRootChain{CK: sharedKey, Crypto: c}, // Populate CKs and CKr with sharedKey as per specification so that both // parties could send and receive messages from the very beginning. - SendCh: chain{CK: sharedKey, Crypto: c}, - RecvCh: chain{CK: sharedKey, Crypto: c}, + SendCh: kdfChain{CK: sharedKey, Crypto: c}, + RecvCh: kdfChain{CK: sharedKey, Crypto: c}, MkSkipped: &KeysStorageInMemory{}, MaxSkip: 1000, MaxKeep: 100, @@ -91,13 +91,13 @@ func (s *state) dhRatchet(m MessageHeader) error { s.DHr = m.DH s.HKs = s.NHKs s.HKr = s.NHKr - s.RecvCh, s.NHKr = s.RootCh.Step(s.Crypto.DH(s.DHs, s.DHr)) + s.RecvCh, s.NHKr = s.RootCh.step(s.Crypto.DH(s.DHs, s.DHr)) var err error s.DHs, err = s.Crypto.GenerateDH() if err != nil { return fmt.Errorf("failed to generate dh pair: %s", err) } - s.SendCh, s.NHKs = s.RootCh.Step(s.Crypto.DH(s.DHs, s.DHr)) + s.SendCh, s.NHKs = s.RootCh.step(s.Crypto.DH(s.DHs, s.DHr)) return nil } @@ -118,7 +118,7 @@ func (s *state) skipMessageKeys(key Key, until uint) ([]skippedKey, error) { } skipped := []skippedKey{} for uint(s.RecvCh.N) < until { - mk := s.RecvCh.Step() + mk := s.RecvCh.step() skipped = append(skipped, skippedKey{ key: key, nr: uint(s.RecvCh.N - 1),