diff --git a/noise_test.go b/noise_test.go index ecdbc89..e097124 100644 --- a/noise_test.go +++ b/noise_test.go @@ -464,3 +464,78 @@ func (NoiseSuite) TestHandshakeRollback(c *C) { expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0") c.Assert(msg, DeepEquals, expected) } + +func (NoiseSuite) TestRekey(c *C) { + rng := new(RandomInc) + + clientStaticKeypair := DH25519.GenerateKeypair(rng) + clientConfig := Config{} + clientConfig.CipherSuite = NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + clientConfig.Random = rng + clientConfig.Pattern = HandshakeNN + clientConfig.Initiator = true + clientConfig.Prologue = []byte{0} + clientConfig.StaticKeypair = clientStaticKeypair + clientConfig.EphemeralKeypair = DH25519.GenerateKeypair(rng) + clientHs := NewHandshakeState(clientConfig) + + serverStaticKeypair := DH25519.GenerateKeypair(rng) + serverConfig := Config{} + serverConfig.CipherSuite = NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + serverConfig.Random = rng + serverConfig.Pattern = HandshakeNN + serverConfig.Initiator = false + serverConfig.Prologue = []byte{0} + serverConfig.StaticKeypair = serverStaticKeypair + serverConfig.EphemeralKeypair = DH25519.GenerateKeypair(rng) + serverHs := NewHandshakeState(serverConfig) + + clientHsMsg, _, _ := clientHs.WriteMessage(nil, nil) + c.Assert(32, Equals, len(clientHsMsg)) + + serverHsResult, _, _, err := serverHs.ReadMessage(nil, clientHsMsg) + c.Assert(err, IsNil) + c.Assert(0, Equals, len(serverHsResult)) + + serverHsMsg, csR0, csR1 := serverHs.WriteMessage(nil, nil) + c.Assert(48, Equals, len(serverHsMsg)) + + clientHsResult, csI0, csI1, err := clientHs.ReadMessage(nil, serverHsMsg) + c.Assert(err, IsNil) + c.Assert(0, Equals, len(clientHsResult)) + + clientMessage := []byte("hello") + msg := csI0.Encrypt(nil, nil, clientMessage) + res, err := csR0.Decrypt(nil, nil, msg) + c.Assert(string(clientMessage), Equals, string(res)) + + oldK := csI0.k + csI0.Rekey() + c.Assert(oldK, Not(Equals), csI0.k) + csR0.Rekey() + + clientMessage = []byte("hello again") + msg = csI0.Encrypt(nil, nil, clientMessage) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(string(clientMessage), Equals, string(res)) + + serverMessage := []byte("bye") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Equals, string(res)) + + csR1.Rekey() + csI1.Rekey() + + serverMessage = []byte("bye bye") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Equals, string(res)) + + // only rekey one side, test for failure + csR1.Rekey() + serverMessage = []byte("bye again") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Not(Equals), string(res)) +} diff --git a/state.go b/state.go index d3772e3..a4d3faf 100644 --- a/state.go +++ b/state.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "math" ) // A CipherState provides symmetric encryption and decryption after a successful @@ -61,6 +62,14 @@ func (s *CipherState) Cipher() Cipher { return s.c } +func (s *CipherState) Rekey() { + var zeros [32]byte + var out []byte + out = s.c.Encrypt(out, math.MaxUint64, []byte{}, zeros[:]) + copy(s.k[:], out[:32]) + s.c = s.cs.Cipher(s.k) +} + type symmetricState struct { CipherState hasK bool