diff --git a/box/box.go b/box/box.go index f69f249..b68860d 100644 --- a/box/box.go +++ b/box/box.go @@ -45,7 +45,6 @@ type Crypter struct { Key Key PeerKey Key ChainVar []byte - KDFNum uint8 scratch [64]byte cc CipherContext @@ -66,7 +65,7 @@ func (c *Crypter) EncryptBody(dst, plaintext, authtext []byte, padLen int) []byt return c.cc.Encrypt(dst, authtext, p) } -func (c *Crypter) EncryptBox(dst []byte, ephKey *Key, plaintext []byte, padLen int) ([]byte, error) { +func (c *Crypter) EncryptBox(dst []byte, ephKey *Key, plaintext []byte, padLen int, kdfNum uint8) ([]byte, error) { if len(c.ChainVar) == 0 { c.ChainVar = make([]byte, CVLen) } @@ -88,8 +87,8 @@ func (c *Crypter) EncryptBox(dst []byte, ephKey *Key, plaintext []byte, padLen i dh1 := c.Cipher.DH(ephKey.Private, c.PeerKey.Public) dh2 := c.Cipher.DH(c.Key.Private, c.PeerKey.Public) - cv1, cc1 := c.deriveKey(dh1, c.ChainVar) - cv2, cc2 := c.deriveKey(dh2, cv1) + cv1, cc1 := c.deriveKey(dh1, c.ChainVar, kdfNum) + cv2, cc2 := c.deriveKey(dh2, cv1, kdfNum+1) c.ChainVar = cv2 dst = append(dst, ephKey.Public...) @@ -115,14 +114,14 @@ func (c *Crypter) cipher(cc []byte) CipherContext { return c.cc } -func (c *Crypter) DecryptBox(ciphertext []byte) ([]byte, error) { +func (c *Crypter) DecryptBox(ciphertext []byte, kdfNum uint8) ([]byte, error) { if len(c.ChainVar) == 0 { c.ChainVar = make([]byte, CVLen) } ephPubKey := ciphertext[:c.Cipher.DHLen()] dh1 := c.Cipher.DH(c.Key.Private, ephPubKey) - cv1, cc1 := c.deriveKey(dh1, c.ChainVar) + cv1, cc1 := c.deriveKey(dh1, c.ChainVar, kdfNum) header := ciphertext[:(2*c.Cipher.DHLen())+c.Cipher.MACLen()] ciphertext = ciphertext[len(header):] @@ -137,7 +136,7 @@ func (c *Crypter) DecryptBox(ciphertext []byte) ([]byte, error) { } dh2 := c.Cipher.DH(c.Key.Private, senderPubKey) - cv2, cc2 := c.deriveKey(dh2, cv1) + cv2, cc2 := c.deriveKey(dh2, cv1, kdfNum+1) c.ChainVar = cv2 body, err := c.cipher(cc2).Decrypt(header, ciphertext) if err != nil { @@ -155,10 +154,9 @@ func (c *Crypter) DecryptBody(authtext, ciphertext []byte) ([]byte, error) { return c.cc.Decrypt(authtext, ciphertext) } -func (c *Crypter) deriveKey(dh, cv []byte) ([]byte, []byte) { - extra := append(c.Cipher.AppendName(c.scratch[:0]), c.KDFNum) +func (c *Crypter) deriveKey(dh, cv []byte, kdfNum uint8) ([]byte, []byte) { + extra := append(c.Cipher.AppendName(c.scratch[:0]), kdfNum) k := DeriveKey(dh, cv, extra, CVLen+c.Cipher.CCLen()) - c.KDFNum++ return k[:CVLen], k[CVLen:] } diff --git a/box/box_test.go b/box/box_test.go index 3a03db8..db58cc5 100644 --- a/box/box_test.go +++ b/box/box_test.go @@ -18,21 +18,21 @@ func (s *S) TestRoundtrip(c *C) { plain := []byte("yellow submarines") padLen := 2 - ciphertext, err := enc.EncryptBox(nil, nil, plain, padLen) + ciphertext, err := enc.EncryptBox(nil, nil, plain, padLen, 0) c.Assert(err, IsNil) expectedLen := len(plain) + padLen + (2 * Noise255.DHLen()) + (2 * Noise255.MACLen()) + 4 c.Assert(ciphertext, HasLen, expectedLen, Commentf("expected: %d", expectedLen)) - plaintext, err := dec.DecryptBox(ciphertext) + plaintext, err := dec.DecryptBox(ciphertext, 0) c.Assert(err, IsNil) c.Assert(plaintext, DeepEquals, plain) plain[0] = 'Y' - ciphertext, err = enc.EncryptBox(nil, nil, plain, 0) + ciphertext, err = enc.EncryptBox(nil, nil, plain, 0, 1) c.Assert(err, IsNil) - plaintext, err = dec.DecryptBox(ciphertext) + plaintext, err = dec.DecryptBox(ciphertext, 1) c.Assert(err, IsNil) c.Assert(plaintext, DeepEquals, plain) } @@ -63,6 +63,6 @@ func BenchmarkEncryptBox(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - enc.EncryptBox(nil, nil, []byte("yellow submarine"), 0) + enc.EncryptBox(nil, nil, []byte("yellow submarine"), 0, 0) } }