diff --git a/noise_test.go b/noise_test.go index ea613c7..5f57066 100644 --- a/noise_test.go +++ b/noise_test.go @@ -465,6 +465,51 @@ func (NoiseSuite) TestHandshakeRollback(c *C) { c.Assert(msg, DeepEquals, expected) } +func (NoiseSuite) TestHandshakeRollback_rs(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512) + rngI := new(RandomInc) + rngR := new(RandomInc) + + staticI, _ := cs.GenerateKeypair(rngI) + staticR, _ := cs.GenerateKeypair(rngR) + + *rngR = 1 + + hsI, _ := NewHandshakeState(Config{ + CipherSuite: cs, + Random: rngI, + Pattern: HandshakeIX, + Initiator: true, + StaticKeypair: staticI, + }) + hsR, _ := NewHandshakeState(Config{ + CipherSuite: cs, + Random: rngR, + Pattern: HandshakeIX, + Initiator: false, + StaticKeypair: staticR, + }) + + msg, _, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 67) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 100) + prev := msg[1] + msg[1] = msg[1] + 1 + _, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, Not(IsNil)) + msg[1] = prev + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(string(res), Equals, "defg") + + expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7cf66fc41515606de81af64a5364fbc0b2cbd71e0837ea590b72b77ae2caaaa93bc19c167c28236a18e0737d395fe95083e41da26a30a8062faf92ed05bbdc36db2369f19b") + c.Assert(msg, DeepEquals, expected) +} + func (NoiseSuite) TestRekey(c *C) { rng := new(RandomInc) diff --git a/state.go b/state.go index 8ce7c56..c4f3161 100644 --- a/state.go +++ b/state.go @@ -406,6 +406,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, return nil, nil, nil, errors.New("noise: no handshake messages left") } + rsSet := false s.ss.Checkpoint() var err error @@ -435,9 +436,13 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, return nil, nil, nil, errors.New("noise: invalid state, rs is not nil") } s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected]) + rsSet = true } if err != nil { s.ss.Rollback() + if rsSet { + s.rs = nil + } return nil, nil, nil, err } message = message[expected:] @@ -464,6 +469,9 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, out, err = s.ss.DecryptAndHash(out, message) if err != nil { s.ss.Rollback() + if rsSet { + s.rs = nil + } return nil, nil, nil, err } s.shouldWrite = true