diff --git a/noise_test.go b/noise_test.go index 214cabe..783a9eb 100644 --- a/noise_test.go +++ b/noise_test.go @@ -227,19 +227,22 @@ func (NoiseSuite) TestXXRoundtrip(c *C) { c.Assert(string(res), Equals, payload) // transport message I -> R - msg = csI0.Encrypt(nil, nil, []byte("wubba")) + msg, err = csI0.Encrypt(nil, nil, []byte("wubba")) + c.Assert(err, IsNil) res, err = csR0.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "wubba") // transport message I -> R again - msg = csI0.Encrypt(nil, nil, []byte("aleph")) + msg, err = csI0.Encrypt(nil, nil, []byte("aleph")) + c.Assert(err, IsNil) res, err = csR0.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "aleph") // transport message R <- I - msg = csR1.Encrypt(nil, nil, []byte("worri")) + msg, err = csR1.Encrypt(nil, nil, []byte("worri")) + c.Assert(err, IsNil) res, err = csI1.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "worri") @@ -280,13 +283,15 @@ func (NoiseSuite) Test_NNpsk0_Roundtrip(c *C) { c.Assert(res, HasLen, 0) // transport I -> R - msg = csI0.Encrypt(nil, nil, []byte("foo")) + msg, err = csI0.Encrypt(nil, nil, []byte("foo")) + c.Assert(err, IsNil) res, err = csR0.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "foo") // transport R -> I - msg = csR1.Encrypt(nil, nil, []byte("bar")) + msg, err = csR1.Encrypt(nil, nil, []byte("bar")) + c.Assert(err, IsNil) res, err = csI1.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "bar") @@ -552,7 +557,8 @@ func (NoiseSuite) TestRekey(c *C) { c.Assert(0, Equals, len(clientHsResult)) clientMessage := []byte("hello") - msg := csI0.Encrypt(nil, nil, clientMessage) + msg, err := csI0.Encrypt(nil, nil, clientMessage) + c.Assert(err, IsNil) res, err := csR0.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(clientMessage), Equals, string(res)) @@ -563,13 +569,15 @@ func (NoiseSuite) TestRekey(c *C) { csR0.Rekey() clientMessage = []byte("hello again") - msg = csI0.Encrypt(nil, nil, clientMessage) + msg, err = csI0.Encrypt(nil, nil, clientMessage) + c.Assert(err, IsNil) res, err = csR0.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(clientMessage), Equals, string(res)) serverMessage := []byte("bye") - msg = csR1.Encrypt(nil, nil, serverMessage) + msg, err = csR1.Encrypt(nil, nil, serverMessage) + c.Assert(err, IsNil) res, err = csI1.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(serverMessage), Equals, string(res)) @@ -578,7 +586,8 @@ func (NoiseSuite) TestRekey(c *C) { csI1.Rekey() serverMessage = []byte("bye bye") - msg = csR1.Encrypt(nil, nil, serverMessage) + msg, err = csR1.Encrypt(nil, nil, serverMessage) + c.Assert(err, IsNil) res, err = csI1.Decrypt(nil, nil, msg) c.Assert(err, IsNil) c.Assert(string(serverMessage), Equals, string(res)) @@ -586,7 +595,8 @@ func (NoiseSuite) TestRekey(c *C) { // only rekey one side, test for failure csR1.Rekey() serverMessage = []byte("bye again") - msg = csR1.Encrypt(nil, nil, serverMessage) + msg, err = csR1.Encrypt(nil, nil, serverMessage) + c.Assert(err, IsNil) res, err = csI1.Decrypt(nil, nil, msg) c.Assert(err, NotNil) c.Assert(string(serverMessage), Not(Equals), string(res)) diff --git a/state.go b/state.go index 418522e..4153e68 100644 --- a/state.go +++ b/state.go @@ -25,17 +25,19 @@ type CipherState struct { invalid bool } +var ErrCipherSuiteCopied = errors.New("noise: CipherSuite has been copied, state is invalid") + // Encrypt encrypts the plaintext and then appends the ciphertext and an // authentication tag across the ciphertext and optional authenticated data to // out. This method automatically increments the nonce after every call, so // messages must be decrypted in the same order. -func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte { +func (s *CipherState) Encrypt(out, ad, plaintext []byte) ([]byte, error) { if s.invalid { - panic("noise: CipherSuite has been copied, state is invalid") + return nil, ErrCipherSuiteCopied } out = s.c.Encrypt(out, s.n, ad, plaintext) s.n++ - return out + return out, nil } // Decrypt checks the authenticity of the ciphertext and authenticated data and @@ -44,7 +46,7 @@ func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte { // order that they were encrypted with no missing messages. func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) { if s.invalid { - panic("noise: CipherSuite has been copied, state is invalid") + return nil, ErrCipherSuiteCopied } out, err := s.c.Decrypt(out, s.n, ad, ciphertext) s.n++ @@ -120,14 +122,17 @@ func (s *symmetricState) MixKeyAndHash(data []byte) { s.hasK = true } -func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte { +func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) { if !s.hasK { s.MixHash(plaintext) - return append(out, plaintext...) + return append(out, plaintext...), nil + } + ciphertext, err := s.Encrypt(out, s.h, plaintext) + if err != nil { + return nil, err } - ciphertext := s.Encrypt(out, s.h, plaintext) s.MixHash(ciphertext[len(out):]) - return ciphertext + return ciphertext, nil } func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { @@ -340,6 +345,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState return nil, nil, nil, errors.New("noise: message is too long") } + var err error for _, msg := range s.messagePatterns[s.msgIdx] { switch msg { case MessagePatternE: @@ -357,7 +363,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState if len(s.s.Public) == 0 { return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") } - out = s.ss.EncryptAndHash(out, s.s.Public) + out, err = s.ss.EncryptAndHash(out, s.s.Public) + if err != nil { + return nil, nil, nil, err + } case MessagePatternDHEE: dh, err := s.ss.cs.DH(s.e.Private, s.re) if err != nil { @@ -404,7 +413,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState } s.shouldWrite = false s.msgIdx++ - out = s.ss.EncryptAndHash(out, payload) + out, err = s.ss.EncryptAndHash(out, payload) + if err != nil { + return nil, nil, nil, err + } if s.msgIdx >= len(s.messagePatterns) { cs1, cs2 := s.ss.Split() diff --git a/vector_test.go b/vector_test.go index 6d17a32..7193b82 100644 --- a/vector_test.go +++ b/vector_test.go @@ -199,7 +199,8 @@ func (NoiseSuite) TestVectors(c *C) { if (i-len(configI.Pattern.Messages))%2 != 0 { enc, dec = csW1, csR1 } - encrypted := enc.Encrypt(nil, nil, payload) + encrypted, err := enc.Encrypt(nil, nil, payload) + c.Assert(err, IsNil) c.Assert(fmt.Sprintf("%x", encrypted), Equals, string(splitLine[1])) decrypted, err := dec.Decrypt(nil, nil, encrypted) c.Assert(err, IsNil) diff --git a/vectorgen/vectorgen.go b/vectorgen/vectorgen.go index 5f5b6b1..f835155 100644 --- a/vectorgen/vectorgen.go +++ b/vectorgen/vectorgen.go @@ -178,7 +178,9 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem payload0 := []byte("yellowsubmarine") payload1 := []byte("submarineyellow") fmt.Fprintf(out, "msg_%d_payload=%x\n", len(h.Messages), payload0) - fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages), cs0.Encrypt(nil, nil, payload0)) + ciphertext0, _ := cs0.Encrypt(nil, nil, payload0) + fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages), ciphertext0) fmt.Fprintf(out, "msg_%d_payload=%x\n", len(h.Messages)+1, payload1) - fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages)+1, cs1.Encrypt(nil, nil, payload1)) + ciphertext1, _ := cs1.Encrypt(nil, nil, payload1) + fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages)+1, ciphertext1) }