diff --git a/noise_test.go b/noise_test.go index ea613c7..d57a4b4 100644 --- a/noise_test.go +++ b/noise_test.go @@ -539,3 +539,79 @@ func (NoiseSuite) TestRekey(c *C) { res, err = csI1.Decrypt(nil, nil, msg) c.Assert(string(serverMessage), Not(Equals), string(res)) } + +func (NoiseSuite) TestSetNonce(c *C) { + rng := new(RandomInc) + + clientStaticKeypair, _ := DH25519.GenerateKeypair(rng) + clientConfig := Config{} + clientConfig.CipherSuite = NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + clientConfig.Random = rng + clientConfig.Pattern = HandshakeNN + clientConfig.Initiator = true + clientConfig.Prologue = []byte{0} + clientConfig.StaticKeypair = clientStaticKeypair + clientConfig.EphemeralKeypair, _ = DH25519.GenerateKeypair(rng) + clientHs, _ := NewHandshakeState(clientConfig) + + serverStaticKeypair, _ := DH25519.GenerateKeypair(rng) + serverConfig := Config{} + serverConfig.CipherSuite = NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + serverConfig.Random = rng + serverConfig.Pattern = HandshakeNN + serverConfig.Initiator = false + serverConfig.Prologue = []byte{0} + serverConfig.StaticKeypair = serverStaticKeypair + serverConfig.EphemeralKeypair, _ = DH25519.GenerateKeypair(rng) + serverHs, _ := NewHandshakeState(serverConfig) + + clientHsMsg, _, _, _ := clientHs.WriteMessage(nil, nil) + c.Assert(32, Equals, len(clientHsMsg)) + + serverHsResult, _, _, err := serverHs.ReadMessage(nil, clientHsMsg) + c.Assert(err, IsNil) + c.Assert(0, Equals, len(serverHsResult)) + + serverHsMsg, csR0, csR1, _ := serverHs.WriteMessage(nil, nil) + c.Assert(48, Equals, len(serverHsMsg)) + + clientHsResult, csI0, csI1, err := clientHs.ReadMessage(nil, serverHsMsg) + c.Assert(err, IsNil) + c.Assert(0, Equals, len(clientHsResult)) + + clientMessage := []byte("hello") + msg := csI0.Encrypt(nil, nil, clientMessage) + res, err := csR0.Decrypt(nil, nil, msg) + c.Assert(string(clientMessage), Equals, string(res)) + + // dropped messages + csI0.Encrypt(nil, nil, clientMessage) + csI0.Encrypt(nil, nil, clientMessage) + csI0.Encrypt(nil, nil, clientMessage) + csI0.Encrypt(nil, nil, clientMessage) + // GetNonce/SetNonce + nonce := csI0.GetNonce() + csR0.SetNonce(nonce) + + clientMessage = []byte("hello again") + msg = csI0.Encrypt(nil, nil, clientMessage) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(string(clientMessage), Equals, string(res)) + + serverMessage := []byte("bye") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Equals, string(res)) + + // dropped messages + csR1.Encrypt(nil, nil, clientMessage) + csR1.Encrypt(nil, nil, clientMessage) + // GetNonce/SetNonce + nonce = csR1.GetNonce() + csI1.SetNonce(nonce) + + serverMessage = []byte("bye bye") + msg = csR1.Encrypt(nil, nil, serverMessage) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(string(serverMessage), Equals, string(res)) +} diff --git a/state.go b/state.go index 8b7b9ba..d6f4ff7 100644 --- a/state.go +++ b/state.go @@ -21,8 +21,23 @@ type CipherState struct { c Cipher k [32]byte n uint64 +} - invalid bool +// GetNonce is a nonce getter useful for out-of-order protocols where +// the nonce must be explicitly sent in addition to encrypted application data. +// It is to be used in conjunction with SetNonce(). More information is available +// in Section 11.4 (Out-of-order transport messages) of the Noise framework protocol. +func (s CipherState) GetNonce() uint64 { + return s.n +} + +// SetNonce is a helper for handling of out-of-order transport messages. +// When receiving an explicit nonce from an encrypted message, SetNonce +// can be used to set the decryption nonce to the received one. More information +// is available in Section 11.4 (Out-of-order transport messages) of the +// Noise framework protocol. +func (s *CipherState) SetNonce(nonce uint64) { + s.n = nonce } // Encrypt encrypts the plaintext and then appends the ciphertext and an @@ -30,9 +45,6 @@ type CipherState struct { // out. This method automatically increments the nonce after every call, so // messages must be decrypted in the same order. func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte { - if s.invalid { - panic("noise: CipherSuite has been copied, state is invalid") - } out = s.c.Encrypt(out, s.n, ad, plaintext) s.n++ return out @@ -43,25 +55,11 @@ func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte { // increments the nonce after every call, messages must be provided in the same // order that they were encrypted with no missing messages. func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) { - if s.invalid { - panic("noise: CipherSuite has been copied, state is invalid") - } out, err := s.c.Decrypt(out, s.n, ad, ciphertext) s.n++ return out, err } -// Cipher returns the low-level symmetric encryption primitive. It should only -// be used if nonces need to be managed manually, for example with a network -// protocol that can deliver out-of-order messages. This is dangerous, users -// must ensure that they are incrementing a nonce after every encrypt operation. -// After calling this method, it is an error to call Encrypt/Decrypt on the -// CipherState. -func (s *CipherState) Cipher() Cipher { - s.invalid = true - return s.c -} - func (s *CipherState) Rekey() { var zeros [32]byte var out []byte