Wrote tests for kdfChain and kdfRootChain
This commit is contained in:
parent
e71374e5da
commit
9b8364b1b0
37
chains.go
37
chains.go
|
@ -11,22 +11,7 @@ type KDFer interface {
|
|||
KdfCK(ck Key) (chainKey, msgKey Key)
|
||||
}
|
||||
|
||||
type rootChain struct {
|
||||
Crypto KDFer
|
||||
|
||||
// 32-byte chain key.
|
||||
CK Key
|
||||
}
|
||||
|
||||
func (c *rootChain) Step(kdfInput Key) (ch chain, nhk Key) {
|
||||
ch = chain{
|
||||
Crypto: c.Crypto,
|
||||
}
|
||||
c.CK, ch.CK, nhk = c.Crypto.KdfRK(c.CK, kdfInput)
|
||||
return ch, nhk
|
||||
}
|
||||
|
||||
type chain struct {
|
||||
type kdfChain struct {
|
||||
Crypto KDFer
|
||||
|
||||
// 32-byte chain key.
|
||||
|
@ -36,10 +21,26 @@ type chain struct {
|
|||
N uint32
|
||||
}
|
||||
|
||||
// Step performs chain step and returns message key.
|
||||
func (c *chain) Step() Key {
|
||||
// step performs symmetric ratchet step and returns a new message key.
|
||||
func (c *kdfChain) step() Key {
|
||||
var mk Key
|
||||
c.CK, mk = c.Crypto.KdfCK(c.CK)
|
||||
c.N++
|
||||
return mk
|
||||
}
|
||||
|
||||
type kdfRootChain struct {
|
||||
Crypto KDFer
|
||||
|
||||
// 32-byte kdfChain key.
|
||||
CK Key
|
||||
}
|
||||
|
||||
// step performs symmetric ratchet step and returns a new chain and new header key.
|
||||
func (c *kdfRootChain) step(kdfInput Key) (ch kdfChain, nhk Key) {
|
||||
ch = kdfChain{
|
||||
Crypto: c.Crypto,
|
||||
}
|
||||
c.CK, ch.CK, nhk = c.Crypto.KdfRK(c.CK, kdfInput)
|
||||
return ch, nhk
|
||||
}
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
package doubleratchet
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var chainKey = Key{0xeb, 0x8, 0x10, 0x7c, 0x33, 0x54, 0x0, 0x20, 0xe9, 0x4f, 0x6c, 0x84, 0xe4, 0x39, 0x50, 0x5a, 0x2f, 0x60, 0xbe, 0x81, 0xa, 0x78, 0x8b, 0xeb, 0x1e, 0x2c, 0x9, 0x8d, 0x4b, 0x4d, 0xc1, 0x40}
|
||||
|
||||
func TestChain_Step(t *testing.T) {
|
||||
// Arrange.
|
||||
ch := kdfChain{
|
||||
Crypto: DefaultCrypto{},
|
||||
CK: chainKey,
|
||||
}
|
||||
|
||||
// Act.
|
||||
mk := ch.step()
|
||||
|
||||
// Assert.
|
||||
require.EqualValues(t, 1, ch.N)
|
||||
require.NotEqual(t, chainKey, ch.CK)
|
||||
require.NotEqual(t, [32]byte{}, mk)
|
||||
}
|
||||
|
||||
func TestRootChain_Step(t *testing.T) {
|
||||
// Arrange.
|
||||
rch := kdfRootChain{
|
||||
Crypto: DefaultCrypto{},
|
||||
CK: chainKey,
|
||||
}
|
||||
|
||||
// Act.
|
||||
ch, nhk := rch.step(Key{0xe3, 0xbe, 0xb9, 0x4e, 0x70, 0x17, 0x37, 0xc, 0x1, 0x8f, 0xa9, 0x7e, 0xef, 0x4, 0xfb, 0x23, 0xac, 0xea, 0x28, 0xf7, 0xa9, 0x56, 0xcc, 0x1d, 0x46, 0xf3, 0xb5, 0x1d, 0x7d, 0x7d, 0x5e, 0x2c})
|
||||
|
||||
// Assert.
|
||||
require.NotEmpty(t, ch.Crypto)
|
||||
require.Empty(t, ch.N)
|
||||
require.NotEqual(t, [32]byte{}, ch)
|
||||
require.NotEqual(t, [32]byte{}, nhk)
|
||||
}
|
|
@ -35,7 +35,7 @@ func NewWithRK(sharedKey, remoteKey Key, opts ...option) (Session, error) {
|
|||
s := sI.(*session)
|
||||
s.DHr = remoteKey
|
||||
// FIXME: Where the header key goes?
|
||||
s.SendCh, _ = s.RootCh.Step(s.Crypto.DH(s.DHs, s.DHr))
|
||||
s.SendCh, _ = s.RootCh.step(s.Crypto.DH(s.DHs, s.DHr))
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ func (s *session) RatchetEncrypt(plaintext, ad []byte) Message {
|
|||
N: s.SendCh.N,
|
||||
PN: s.PN,
|
||||
}
|
||||
mk = s.SendCh.Step()
|
||||
mk = s.SendCh.step()
|
||||
)
|
||||
adBuf = append(adBuf, ad...)
|
||||
ct := s.Crypto.Encrypt(mk, plaintext, append(adBuf, h.Encode()...))
|
||||
|
@ -93,7 +93,7 @@ func (s *session) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
if skippedKeys2, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.N)); err != nil {
|
||||
return nil, fmt.Errorf("can't skip current chain message keys: %s", err)
|
||||
}
|
||||
mk := sc.RecvCh.Step()
|
||||
mk := sc.RecvCh.step()
|
||||
plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, append(ad, m.Header.Encode()...))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||
|
|
|
@ -27,7 +27,7 @@ func (s *sessionHE) RatchetEncryptHE(plaintext, ad []byte) MessageHE {
|
|||
N: s.SendCh.N,
|
||||
PN: s.PN,
|
||||
}
|
||||
mk = s.SendCh.Step()
|
||||
mk = s.SendCh.step()
|
||||
hEnc = s.Crypto.Encrypt(s.HKs, h.Encode(), nil)
|
||||
)
|
||||
return MessageHE{
|
||||
|
@ -87,7 +87,7 @@ func (s *sessionHE) RatchetDecryptHE(m MessageHE, ad []byte) ([]byte, error) {
|
|||
if skippedKeys2, err = sc.skipMessageKeys(s.HKr, uint(h.N)); err != nil {
|
||||
return nil, fmt.Errorf("can't skip current chain message keys: %s", err)
|
||||
}
|
||||
mk := sc.RecvCh.Step()
|
||||
mk := sc.RecvCh.step()
|
||||
plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, append(ad, m.Header...))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||
|
|
16
state.go
16
state.go
|
@ -19,10 +19,10 @@ type state struct {
|
|||
DHs DHPair
|
||||
|
||||
// Symmetric ratchet root chain.
|
||||
RootCh rootChain
|
||||
RootCh kdfRootChain
|
||||
|
||||
// Symmetric ratchet sending and receiving chains.
|
||||
SendCh, RecvCh chain
|
||||
SendCh, RecvCh kdfChain
|
||||
|
||||
// Number of messages in previous sending chain.
|
||||
PN uint32
|
||||
|
@ -65,11 +65,11 @@ func newState(sharedKey Key, opts ...option) (state, error) {
|
|||
s := state{
|
||||
Crypto: c,
|
||||
DHs: dhs,
|
||||
RootCh: rootChain{CK: sharedKey, Crypto: c},
|
||||
RootCh: kdfRootChain{CK: sharedKey, Crypto: c},
|
||||
// Populate CKs and CKr with sharedKey as per specification so that both
|
||||
// parties could send and receive messages from the very beginning.
|
||||
SendCh: chain{CK: sharedKey, Crypto: c},
|
||||
RecvCh: chain{CK: sharedKey, Crypto: c},
|
||||
SendCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||
RecvCh: kdfChain{CK: sharedKey, Crypto: c},
|
||||
MkSkipped: &KeysStorageInMemory{},
|
||||
MaxSkip: 1000,
|
||||
MaxKeep: 100,
|
||||
|
@ -91,13 +91,13 @@ func (s *state) dhRatchet(m MessageHeader) error {
|
|||
s.DHr = m.DH
|
||||
s.HKs = s.NHKs
|
||||
s.HKr = s.NHKr
|
||||
s.RecvCh, s.NHKr = s.RootCh.Step(s.Crypto.DH(s.DHs, s.DHr))
|
||||
s.RecvCh, s.NHKr = s.RootCh.step(s.Crypto.DH(s.DHs, s.DHr))
|
||||
var err error
|
||||
s.DHs, err = s.Crypto.GenerateDH()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate dh pair: %s", err)
|
||||
}
|
||||
s.SendCh, s.NHKs = s.RootCh.Step(s.Crypto.DH(s.DHs, s.DHr))
|
||||
s.SendCh, s.NHKs = s.RootCh.step(s.Crypto.DH(s.DHs, s.DHr))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ func (s *state) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
|
|||
}
|
||||
skipped := []skippedKey{}
|
||||
for uint(s.RecvCh.N) < until {
|
||||
mk := s.RecvCh.Step()
|
||||
mk := s.RecvCh.step()
|
||||
skipped = append(skipped, skippedKey{
|
||||
key: key,
|
||||
nr: uint(s.RecvCh.N - 1),
|
||||
|
|
Loading…
Reference in New Issue