diff --git a/noise_test.go b/noise_test.go index b41ed98..4e00235 100644 --- a/noise_test.go +++ b/noise_test.go @@ -312,3 +312,32 @@ func (NoiseSuite) TestPSK_XX(c *C) { expected, _ := hex.DecodeString("2b9c628158a517e3984dc619245d4b9cd73561944f266181b183812ca73499881e30f6e7eeb576c258acc713c2c62874fd1beb76b122f6303f974109aefd7e2a") c.Assert(msg, DeepEquals, expected) } + +func (NoiseSuite) TestHandshakeRollback(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeNN, Initiator: true}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeNN, Initiator: false}) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 35) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 52) + prev := msg[1] + msg[1] = msg[1] + 1 + _, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, Not(IsNil)) + msg[1] = prev + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(string(res), Equals, "defg") + + expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0") + c.Assert(msg, DeepEquals, expected) +} diff --git a/state.go b/state.go index 2538dd3..dba6491 100644 --- a/state.go +++ b/state.go @@ -66,6 +66,9 @@ type symmetricState struct { hasPSK bool ck []byte h []byte + + prevCK []byte + prevH []byte } func (s *symmetricState) InitializeSymmetric(handshakeName []byte) { @@ -137,6 +140,27 @@ func (s *symmetricState) Split() (*CipherState, *CipherState) { return s1, s2 } +func (s *symmetricState) Checkpoint() { + if len(s.ck) > cap(s.prevCK) { + s.prevCK = make([]byte, len(s.ck)) + } + s.prevCK = s.prevCK[:len(s.ck)] + copy(s.prevCK, s.ck) + + if len(s.h) > cap(s.prevH) { + s.prevH = make([]byte, len(s.h)) + } + s.prevH = s.prevH[:len(s.h)] + copy(s.prevH, s.h) +} + +func (s *symmetricState) Rollback() { + s.ck = s.ck[:len(s.prevCK)] + copy(s.ck, s.prevCK) + s.h = s.h[:len(s.prevH)] + copy(s.h, s.prevH) +} + // A MessagePattern is a single message or operation used in a Noise handshake. type MessagePattern int @@ -340,6 +364,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, panic("noise: no handshake messages left") } + s.ss.Checkpoint() + var err error for _, msg := range s.messagePatterns[s.msgIdx] { switch msg { @@ -369,6 +395,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected]) } if err != nil { + s.ss.Rollback() return nil, nil, nil, err } message = message[expected:] @@ -382,12 +409,13 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) } } - s.shouldWrite = true - s.msgIdx++ out, err = s.ss.DecryptAndHash(out, message) if err != nil { + s.ss.Rollback() return nil, nil, nil, err } + s.shouldWrite = true + s.msgIdx++ if s.msgIdx >= len(s.messagePatterns) { cs1, cs2 := s.ss.Split()