diff --git a/noise_test.go b/noise_test.go index ecdbc89..e097124 100644 --- a/noise_test.go +++ b/noise_test.go @@ -464,3 +464,78 @@ func (NoiseSuite) TestHandshakeRollback(c *C) { expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0") c.Assert(msg, DeepEquals, expected) } + +func (NoiseSuite) TestRekey(c *C) { + rng := new(RandomInc) + + clientStaticKeypair := DH25519.GenerateKeypair(rng) + clientConfig := Config{} + clientConfig.CipherSuite = NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + clientConfig.Random = rng + clientConfig.Pattern = HandshakeNN + clientConfig.Initiator = true + clientConfig.Prologue = []byte{0} + clientConfig.StaticKeypair = clientStaticKeypair + clientConfig.EphemeralKeypair = DH25519.GenerateKeypair(rng) + clientHs := NewHandshakeState(clientConfig) + + serverStaticKeypair := DH25519.GenerateKeypair(rng) + serverConfig := Config{} + serverConfig.CipherSuite = NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + serverConfig.Random = rng + serverConfig.Pattern = HandshakeNN + serverConfig.Initiator = false + serverConfig.Prologue = []byte{0} + serverConfig.StaticKeypair = serverStaticKeypair + serverConfig.EphemeralKeypair = DH25519.GenerateKeypair(rng) + serverHs := NewHandshakeState(serverConfig) + + clientHsMsg, _, _ := clientHs.WriteMessage(nil, nil) + c.Assert(32, Equals, len(clientHsMsg)) + + serverHsResult, _, _, err := serverHs.ReadMessage(nil, clientHsMsg) + c.Assert(err, IsNil) + c.Assert(0, Equals, len(serverHsResult)) + + serverHsMsg, csR0, csR1 := serverHs.WriteMessage(nil, nil) + c.Assert(48, Equals, len(serverHsMsg)) + + clientHsResult, csI0, csI1, err := clientHs.ReadMessage(nil, serverHsMsg) + c.Assert(err, IsNil) + c.Assert(0, Equals, len(clientHsResult)) + + clientMessage := []byte("hello") + msg := csI0.Encrypt(nil, nil, clientMessage) + res, err := csR0.Decrypt(nil, nil, msg) + c.Assert(string(clientMessage), Equals, string(res)) + + oldK := csI0.k + csI0.Rekey() + c.Assert(oldK, Not(Equals), csI0.k) + csR0.Rekey() + + clientMessage = []byte("hello again") + msg = csI0.Encrypt(nil, nil, clientMessage) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(string(clientMessage), Equals, string(res)) + + serverMessage := []byte("bye") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Equals, string(res)) + + csR1.Rekey() + csI1.Rekey() + + serverMessage = []byte("bye bye") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Equals, string(res)) + + // only rekey one side, test for failure + csR1.Rekey() + serverMessage = []byte("bye again") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Not(Equals), string(res)) +} diff --git a/state.go b/state.go index 705d6af..a4d3faf 100644 --- a/state.go +++ b/state.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "math" ) // A CipherState provides symmetric encryption and decryption after a successful @@ -61,11 +62,19 @@ func (s *CipherState) Cipher() Cipher { return s.c } +func (s *CipherState) Rekey() { + var zeros [32]byte + var out []byte + out = s.c.Encrypt(out, math.MaxUint64, []byte{}, zeros[:]) + copy(s.k[:], out[:32]) + s.c = s.cs.Cipher(s.k) +} + type symmetricState struct { CipherState - hasK bool - ck []byte - h []byte + hasK bool + ck []byte + h []byte prevCK []byte prevH []byte @@ -199,7 +208,7 @@ type HandshakeState struct { e DHKey // local ephemeral keypair rs []byte // remote party's static public key re []byte // remote party's ephemeral public key - psk []byte // preshared key, maybe zero length + psk []byte // preshared key, maybe zero length messagePatterns [][]MessagePattern shouldWrite bool msgIdx int @@ -277,10 +286,10 @@ func NewHandshakeState(c Config) *HandshakeState { } pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement) hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...) - if (c.PresharedKeyPlacement == 0) { + if c.PresharedKeyPlacement == 0 { hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...) } else { - hs.messagePatterns[c.PresharedKeyPlacement - 1] = append(hs.messagePatterns[c.PresharedKeyPlacement - 1], MessagePatternPSK) + hs.messagePatterns[c.PresharedKeyPlacement-1] = append(hs.messagePatterns[c.PresharedKeyPlacement-1], MessagePatternPSK) } } hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))