diff --git a/noise_test.go b/noise_test.go index ab7c0ca..e8431df 100644 --- a/noise_test.go +++ b/noise_test.go @@ -518,6 +518,65 @@ func (NoiseSuite) TestHandshakeRollback_rs(c *C) { c.Assert(msg, DeepEquals, expected) } +func (NoiseSuite) TestSetNonce(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + hsI, _ := NewHandshakeState(Config{ + CipherSuite: cs, + Random: rngI, + Pattern: HandshakeNN, + Initiator: true, + }) + hsR, _ := NewHandshakeState(Config{ + CipherSuite: cs, + Random: rngR, + Pattern: HandshakeNN, + Initiator: false, + }) + + msg, _, _, _ := hsI.WriteMessage(nil, nil) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + msg, csR0, csR1, _ := hsR.WriteMessage(nil, nil) + res, csI0, csI1, err := hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + c.Assert(csI0.Nonce(), Equals, uint64(0)) + c.Assert(csI1.Nonce(), Equals, uint64(0)) + c.Assert(csR0.Nonce(), Equals, uint64(0)) + c.Assert(csR1.Nonce(), Equals, uint64(0)) + + const n = 1234 + clientMessage := []byte("msg1") + csI0.SetNonce(n) + msg, err = csI0.Encrypt(nil, nil, clientMessage) + c.Assert(err, IsNil) + // decrypt with incorrect nonce + _, err = csR0.Decrypt(nil, nil, msg) + c.Assert(err, NotNil) + // decrypt with correct nonce + csR0.SetNonce(n) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(err, IsNil) + c.Assert(string(clientMessage), Equals, string(res)) + + c.Assert(csI0.Nonce(), Equals, uint64(n+1)) + c.Assert(csI1.Nonce(), Equals, uint64(0)) + c.Assert(csR0.Nonce(), Equals, uint64(n+1)) + c.Assert(csR1.Nonce(), Equals, uint64(0)) + + serverMessage := []byte("msg2") + csR1.SetNonce(MaxNonce + 1) + _, err = csR1.Encrypt(nil, nil, serverMessage) + c.Assert(err, Equals, ErrMaxNonce) +} + func (NoiseSuite) TestRekey(c *C) { rng := new(RandomInc) diff --git a/state.go b/state.go index 6e9577f..985eea5 100644 --- a/state.go +++ b/state.go @@ -86,6 +86,11 @@ func (s *CipherState) Nonce() uint64 { return s.n } +// SetNonce sets the current value of n. +func (s *CipherState) SetNonce(n uint64) { + s.n = n +} + func (s *CipherState) Rekey() { var zeros [32]byte var out []byte