Roll back static key in the face of failure

This commit is contained in:
Nate Brown 2021-03-24 20:23:37 -05:00 committed by Jonathan Rudenberg
parent 2492fe189a
commit 5a73e02a33
2 changed files with 53 additions and 0 deletions

View File

@ -465,6 +465,51 @@ func (NoiseSuite) TestHandshakeRollback(c *C) {
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestHandshakeRollback_rs(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512)
rngI := new(RandomInc)
rngR := new(RandomInc)
staticI, _ := cs.GenerateKeypair(rngI)
staticR, _ := cs.GenerateKeypair(rngR)
*rngR = 1
hsI, _ := NewHandshakeState(Config{
CipherSuite: cs,
Random: rngI,
Pattern: HandshakeIX,
Initiator: true,
StaticKeypair: staticI,
})
hsR, _ := NewHandshakeState(Config{
CipherSuite: cs,
Random: rngR,
Pattern: HandshakeIX,
Initiator: false,
StaticKeypair: staticR,
})
msg, _, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 67)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 100)
prev := msg[1]
msg[1] = msg[1] + 1
_, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, Not(IsNil))
msg[1] = prev
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7cf66fc41515606de81af64a5364fbc0b2cbd71e0837ea590b72b77ae2caaaa93bc19c167c28236a18e0737d395fe95083e41da26a30a8062faf92ed05bbdc36db2369f19b")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestRekey(c *C) {
rng := new(RandomInc)

View File

@ -406,6 +406,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
rsSet := false
s.ss.Checkpoint()
var err error
@ -435,9 +436,13 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
}
s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
rsSet = true
}
if err != nil {
s.ss.Rollback()
if rsSet {
s.rs = nil
}
return nil, nil, nil, err
}
message = message[expected:]
@ -464,6 +469,9 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
out, err = s.ss.DecryptAndHash(out, message)
if err != nil {
s.ss.Rollback()
if rsSet {
s.rs = nil
}
return nil, nil, nil, err
}
s.shouldWrite = true