mirror of https://github.com/waku-org/noise.git
Ensure that the handshake state doesn’t get lost if decryption fails
This allows decoding potentially malicious messages under certain circumstances. Signed-off-by: Jonathan Rudenberg <jonathan@titanous.com>
This commit is contained in:
parent
bc8ae75ea2
commit
7e06e15681
|
@ -312,3 +312,32 @@ func (NoiseSuite) TestPSK_XX(c *C) {
|
|||
expected, _ := hex.DecodeString("2b9c628158a517e3984dc619245d4b9cd73561944f266181b183812ca73499881e30f6e7eeb576c258acc713c2c62874fd1beb76b122f6303f974109aefd7e2a")
|
||||
c.Assert(msg, DeepEquals, expected)
|
||||
}
|
||||
|
||||
func (NoiseSuite) TestHandshakeRollback(c *C) {
|
||||
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512)
|
||||
rngI := new(RandomInc)
|
||||
rngR := new(RandomInc)
|
||||
*rngR = 1
|
||||
|
||||
hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeNN, Initiator: true})
|
||||
hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeNN, Initiator: false})
|
||||
|
||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||
c.Assert(msg, HasLen, 35)
|
||||
res, _, _, err := hsR.ReadMessage(nil, msg)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(string(res), Equals, "abc")
|
||||
|
||||
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
|
||||
c.Assert(msg, HasLen, 52)
|
||||
prev := msg[1]
|
||||
msg[1] = msg[1] + 1
|
||||
_, _, _, err = hsI.ReadMessage(nil, msg)
|
||||
c.Assert(err, Not(IsNil))
|
||||
msg[1] = prev
|
||||
res, _, _, err = hsI.ReadMessage(nil, msg)
|
||||
c.Assert(string(res), Equals, "defg")
|
||||
|
||||
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0")
|
||||
c.Assert(msg, DeepEquals, expected)
|
||||
}
|
||||
|
|
32
state.go
32
state.go
|
@ -66,6 +66,9 @@ type symmetricState struct {
|
|||
hasPSK bool
|
||||
ck []byte
|
||||
h []byte
|
||||
|
||||
prevCK []byte
|
||||
prevH []byte
|
||||
}
|
||||
|
||||
func (s *symmetricState) InitializeSymmetric(handshakeName []byte) {
|
||||
|
@ -137,6 +140,27 @@ func (s *symmetricState) Split() (*CipherState, *CipherState) {
|
|||
return s1, s2
|
||||
}
|
||||
|
||||
func (s *symmetricState) Checkpoint() {
|
||||
if len(s.ck) > cap(s.prevCK) {
|
||||
s.prevCK = make([]byte, len(s.ck))
|
||||
}
|
||||
s.prevCK = s.prevCK[:len(s.ck)]
|
||||
copy(s.prevCK, s.ck)
|
||||
|
||||
if len(s.h) > cap(s.prevH) {
|
||||
s.prevH = make([]byte, len(s.h))
|
||||
}
|
||||
s.prevH = s.prevH[:len(s.h)]
|
||||
copy(s.prevH, s.h)
|
||||
}
|
||||
|
||||
func (s *symmetricState) Rollback() {
|
||||
s.ck = s.ck[:len(s.prevCK)]
|
||||
copy(s.ck, s.prevCK)
|
||||
s.h = s.h[:len(s.prevH)]
|
||||
copy(s.h, s.prevH)
|
||||
}
|
||||
|
||||
// A MessagePattern is a single message or operation used in a Noise handshake.
|
||||
type MessagePattern int
|
||||
|
||||
|
@ -340,6 +364,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
panic("noise: no handshake messages left")
|
||||
}
|
||||
|
||||
s.ss.Checkpoint()
|
||||
|
||||
var err error
|
||||
for _, msg := range s.messagePatterns[s.msgIdx] {
|
||||
switch msg {
|
||||
|
@ -369,6 +395,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
|
||||
}
|
||||
if err != nil {
|
||||
s.ss.Rollback()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
message = message[expected:]
|
||||
|
@ -382,12 +409,13 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
|
||||
}
|
||||
}
|
||||
s.shouldWrite = true
|
||||
s.msgIdx++
|
||||
out, err = s.ss.DecryptAndHash(out, message)
|
||||
if err != nil {
|
||||
s.ss.Rollback()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
s.shouldWrite = true
|
||||
s.msgIdx++
|
||||
|
||||
if s.msgIdx >= len(s.messagePatterns) {
|
||||
cs1, cs2 := s.ss.Split()
|
||||
|
|
Loading…
Reference in New Issue