mirror of
https://github.com/logos-messaging/noise.git
synced 2026-01-02 14:13:07 +00:00
Return error from CipherSuite.Encrypt
This commit is contained in:
parent
fc2bb37e28
commit
0d4f803fc7
@ -227,19 +227,22 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
|
|||||||
c.Assert(string(res), Equals, payload)
|
c.Assert(string(res), Equals, payload)
|
||||||
|
|
||||||
// transport message I -> R
|
// transport message I -> R
|
||||||
msg = csI0.Encrypt(nil, nil, []byte("wubba"))
|
msg, err = csI0.Encrypt(nil, nil, []byte("wubba"))
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csR0.Decrypt(nil, nil, msg)
|
res, err = csR0.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(res), Equals, "wubba")
|
c.Assert(string(res), Equals, "wubba")
|
||||||
|
|
||||||
// transport message I -> R again
|
// transport message I -> R again
|
||||||
msg = csI0.Encrypt(nil, nil, []byte("aleph"))
|
msg, err = csI0.Encrypt(nil, nil, []byte("aleph"))
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csR0.Decrypt(nil, nil, msg)
|
res, err = csR0.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(res), Equals, "aleph")
|
c.Assert(string(res), Equals, "aleph")
|
||||||
|
|
||||||
// transport message R <- I
|
// transport message R <- I
|
||||||
msg = csR1.Encrypt(nil, nil, []byte("worri"))
|
msg, err = csR1.Encrypt(nil, nil, []byte("worri"))
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csI1.Decrypt(nil, nil, msg)
|
res, err = csI1.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(res), Equals, "worri")
|
c.Assert(string(res), Equals, "worri")
|
||||||
@ -280,13 +283,15 @@ func (NoiseSuite) Test_NNpsk0_Roundtrip(c *C) {
|
|||||||
c.Assert(res, HasLen, 0)
|
c.Assert(res, HasLen, 0)
|
||||||
|
|
||||||
// transport I -> R
|
// transport I -> R
|
||||||
msg = csI0.Encrypt(nil, nil, []byte("foo"))
|
msg, err = csI0.Encrypt(nil, nil, []byte("foo"))
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csR0.Decrypt(nil, nil, msg)
|
res, err = csR0.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(res), Equals, "foo")
|
c.Assert(string(res), Equals, "foo")
|
||||||
|
|
||||||
// transport R -> I
|
// transport R -> I
|
||||||
msg = csR1.Encrypt(nil, nil, []byte("bar"))
|
msg, err = csR1.Encrypt(nil, nil, []byte("bar"))
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csI1.Decrypt(nil, nil, msg)
|
res, err = csI1.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(res), Equals, "bar")
|
c.Assert(string(res), Equals, "bar")
|
||||||
@ -552,7 +557,8 @@ func (NoiseSuite) TestRekey(c *C) {
|
|||||||
c.Assert(0, Equals, len(clientHsResult))
|
c.Assert(0, Equals, len(clientHsResult))
|
||||||
|
|
||||||
clientMessage := []byte("hello")
|
clientMessage := []byte("hello")
|
||||||
msg := csI0.Encrypt(nil, nil, clientMessage)
|
msg, err := csI0.Encrypt(nil, nil, clientMessage)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err := csR0.Decrypt(nil, nil, msg)
|
res, err := csR0.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(clientMessage), Equals, string(res))
|
c.Assert(string(clientMessage), Equals, string(res))
|
||||||
@ -563,13 +569,15 @@ func (NoiseSuite) TestRekey(c *C) {
|
|||||||
csR0.Rekey()
|
csR0.Rekey()
|
||||||
|
|
||||||
clientMessage = []byte("hello again")
|
clientMessage = []byte("hello again")
|
||||||
msg = csI0.Encrypt(nil, nil, clientMessage)
|
msg, err = csI0.Encrypt(nil, nil, clientMessage)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csR0.Decrypt(nil, nil, msg)
|
res, err = csR0.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(clientMessage), Equals, string(res))
|
c.Assert(string(clientMessage), Equals, string(res))
|
||||||
|
|
||||||
serverMessage := []byte("bye")
|
serverMessage := []byte("bye")
|
||||||
msg = csR1.Encrypt(nil, nil, serverMessage)
|
msg, err = csR1.Encrypt(nil, nil, serverMessage)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csI1.Decrypt(nil, nil, msg)
|
res, err = csI1.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(serverMessage), Equals, string(res))
|
c.Assert(string(serverMessage), Equals, string(res))
|
||||||
@ -578,7 +586,8 @@ func (NoiseSuite) TestRekey(c *C) {
|
|||||||
csI1.Rekey()
|
csI1.Rekey()
|
||||||
|
|
||||||
serverMessage = []byte("bye bye")
|
serverMessage = []byte("bye bye")
|
||||||
msg = csR1.Encrypt(nil, nil, serverMessage)
|
msg, err = csR1.Encrypt(nil, nil, serverMessage)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csI1.Decrypt(nil, nil, msg)
|
res, err = csI1.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(serverMessage), Equals, string(res))
|
c.Assert(string(serverMessage), Equals, string(res))
|
||||||
@ -586,7 +595,8 @@ func (NoiseSuite) TestRekey(c *C) {
|
|||||||
// only rekey one side, test for failure
|
// only rekey one side, test for failure
|
||||||
csR1.Rekey()
|
csR1.Rekey()
|
||||||
serverMessage = []byte("bye again")
|
serverMessage = []byte("bye again")
|
||||||
msg = csR1.Encrypt(nil, nil, serverMessage)
|
msg, err = csR1.Encrypt(nil, nil, serverMessage)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
res, err = csI1.Decrypt(nil, nil, msg)
|
res, err = csI1.Decrypt(nil, nil, msg)
|
||||||
c.Assert(err, NotNil)
|
c.Assert(err, NotNil)
|
||||||
c.Assert(string(serverMessage), Not(Equals), string(res))
|
c.Assert(string(serverMessage), Not(Equals), string(res))
|
||||||
|
|||||||
32
state.go
32
state.go
@ -25,17 +25,19 @@ type CipherState struct {
|
|||||||
invalid bool
|
invalid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrCipherSuiteCopied = errors.New("noise: CipherSuite has been copied, state is invalid")
|
||||||
|
|
||||||
// Encrypt encrypts the plaintext and then appends the ciphertext and an
|
// Encrypt encrypts the plaintext and then appends the ciphertext and an
|
||||||
// authentication tag across the ciphertext and optional authenticated data to
|
// authentication tag across the ciphertext and optional authenticated data to
|
||||||
// out. This method automatically increments the nonce after every call, so
|
// out. This method automatically increments the nonce after every call, so
|
||||||
// messages must be decrypted in the same order.
|
// messages must be decrypted in the same order.
|
||||||
func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
|
func (s *CipherState) Encrypt(out, ad, plaintext []byte) ([]byte, error) {
|
||||||
if s.invalid {
|
if s.invalid {
|
||||||
panic("noise: CipherSuite has been copied, state is invalid")
|
return nil, ErrCipherSuiteCopied
|
||||||
}
|
}
|
||||||
out = s.c.Encrypt(out, s.n, ad, plaintext)
|
out = s.c.Encrypt(out, s.n, ad, plaintext)
|
||||||
s.n++
|
s.n++
|
||||||
return out
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt checks the authenticity of the ciphertext and authenticated data and
|
// Decrypt checks the authenticity of the ciphertext and authenticated data and
|
||||||
@ -44,7 +46,7 @@ func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
|
|||||||
// order that they were encrypted with no missing messages.
|
// order that they were encrypted with no missing messages.
|
||||||
func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
|
func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
|
||||||
if s.invalid {
|
if s.invalid {
|
||||||
panic("noise: CipherSuite has been copied, state is invalid")
|
return nil, ErrCipherSuiteCopied
|
||||||
}
|
}
|
||||||
out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
|
out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
|
||||||
s.n++
|
s.n++
|
||||||
@ -120,14 +122,17 @@ func (s *symmetricState) MixKeyAndHash(data []byte) {
|
|||||||
s.hasK = true
|
s.hasK = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte {
|
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
|
||||||
if !s.hasK {
|
if !s.hasK {
|
||||||
s.MixHash(plaintext)
|
s.MixHash(plaintext)
|
||||||
return append(out, plaintext...)
|
return append(out, plaintext...), nil
|
||||||
|
}
|
||||||
|
ciphertext, err := s.Encrypt(out, s.h, plaintext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
ciphertext := s.Encrypt(out, s.h, plaintext)
|
|
||||||
s.MixHash(ciphertext[len(out):])
|
s.MixHash(ciphertext[len(out):])
|
||||||
return ciphertext
|
return ciphertext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
|
func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
|
||||||
@ -340,6 +345,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||||||
return nil, nil, nil, errors.New("noise: message is too long")
|
return nil, nil, nil, errors.New("noise: message is too long")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
for _, msg := range s.messagePatterns[s.msgIdx] {
|
for _, msg := range s.messagePatterns[s.msgIdx] {
|
||||||
switch msg {
|
switch msg {
|
||||||
case MessagePatternE:
|
case MessagePatternE:
|
||||||
@ -357,7 +363,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||||||
if len(s.s.Public) == 0 {
|
if len(s.s.Public) == 0 {
|
||||||
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
|
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
|
||||||
}
|
}
|
||||||
out = s.ss.EncryptAndHash(out, s.s.Public)
|
out, err = s.ss.EncryptAndHash(out, s.s.Public)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
case MessagePatternDHEE:
|
case MessagePatternDHEE:
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.re)
|
dh, err := s.ss.cs.DH(s.e.Private, s.re)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -404,7 +413,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||||||
}
|
}
|
||||||
s.shouldWrite = false
|
s.shouldWrite = false
|
||||||
s.msgIdx++
|
s.msgIdx++
|
||||||
out = s.ss.EncryptAndHash(out, payload)
|
out, err = s.ss.EncryptAndHash(out, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if s.msgIdx >= len(s.messagePatterns) {
|
if s.msgIdx >= len(s.messagePatterns) {
|
||||||
cs1, cs2 := s.ss.Split()
|
cs1, cs2 := s.ss.Split()
|
||||||
|
|||||||
@ -199,7 +199,8 @@ func (NoiseSuite) TestVectors(c *C) {
|
|||||||
if (i-len(configI.Pattern.Messages))%2 != 0 {
|
if (i-len(configI.Pattern.Messages))%2 != 0 {
|
||||||
enc, dec = csW1, csR1
|
enc, dec = csW1, csR1
|
||||||
}
|
}
|
||||||
encrypted := enc.Encrypt(nil, nil, payload)
|
encrypted, err := enc.Encrypt(nil, nil, payload)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
c.Assert(fmt.Sprintf("%x", encrypted), Equals, string(splitLine[1]))
|
c.Assert(fmt.Sprintf("%x", encrypted), Equals, string(splitLine[1]))
|
||||||
decrypted, err := dec.Decrypt(nil, nil, encrypted)
|
decrypted, err := dec.Decrypt(nil, nil, encrypted)
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
|
|||||||
@ -178,7 +178,9 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
|
|||||||
payload0 := []byte("yellowsubmarine")
|
payload0 := []byte("yellowsubmarine")
|
||||||
payload1 := []byte("submarineyellow")
|
payload1 := []byte("submarineyellow")
|
||||||
fmt.Fprintf(out, "msg_%d_payload=%x\n", len(h.Messages), payload0)
|
fmt.Fprintf(out, "msg_%d_payload=%x\n", len(h.Messages), payload0)
|
||||||
fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages), cs0.Encrypt(nil, nil, payload0))
|
ciphertext0, _ := cs0.Encrypt(nil, nil, payload0)
|
||||||
|
fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages), ciphertext0)
|
||||||
fmt.Fprintf(out, "msg_%d_payload=%x\n", len(h.Messages)+1, payload1)
|
fmt.Fprintf(out, "msg_%d_payload=%x\n", len(h.Messages)+1, payload1)
|
||||||
fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages)+1, cs1.Encrypt(nil, nil, payload1))
|
ciphertext1, _ := cs1.Encrypt(nil, nil, payload1)
|
||||||
|
fmt.Fprintf(out, "msg_%d_ciphertext=%x\n", len(h.Messages)+1, ciphertext1)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user