MixPresharedKey and new vectors

This commit is contained in:
Jonathan Rudenberg 2015-11-16 11:37:33 -05:00
parent 0e9c47ad19
commit d760bc3534
2 changed files with 100 additions and 17 deletions

View File

@ -124,12 +124,12 @@ func (NoiseSuite) TestIK(c *C) {
c.Assert(string(res), Equals, "abc") c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 68) c.Assert(msg, HasLen, 52)
res, _, _, err = hsI.ReadMessage(nil, msg) res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg") c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("5a491c3d8524aee516e7edccba51433ebe651002f0f79fd79dc6a4bf65ecd7b13543f1cc7910a367ffc3686f9c03e62e7555a9411133bb3194f27a9433507b30d858d578") expected, _ := hex.DecodeString("5869aff450549732cbaaed5e5df9b30a6da31cb0e5742bad5ad4a1a768f1a67b7555a94199d0ce2972e0861b06c2152419a278de")
c.Assert(msg, DeepEquals, expected) c.Assert(msg, DeepEquals, expected)
} }
@ -153,7 +153,7 @@ func (NoiseSuite) TestXE(c *C) {
c.Assert(string(res), Equals, "abc") c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 68) c.Assert(msg, HasLen, 52)
res, _, _, err = hsI.ReadMessage(nil, msg) res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg") c.Assert(string(res), Equals, "defg")
@ -164,7 +164,7 @@ func (NoiseSuite) TestXE(c *C) {
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(res, HasLen, 0) c.Assert(res, HasLen, 0)
expected, _ := hex.DecodeString("08439f380b6f128a1465840d558f06abb1141cf5708a9dcf573d6e4fae01f90f7c9b8ef856bdc483df643a9d240ab6d38d9af9f3812ef44a465e32f8227a7c8b") expected, _ := hex.DecodeString("08439f380b6f128a1465840d558f06abb1141cf5708a9dcf573d6e4fae01f90fd68dec89b26b249f2c4c61add5a1dbcf0a652ef015d7dbe0e80e9ea9af0aa7a2")
c.Assert(msg, DeepEquals, expected) c.Assert(msg, DeepEquals, expected)
} }
@ -228,7 +228,7 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
*rngR = 1 *rngR = 1
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil) hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil)
hsR := NewHandshakeState(cs, rngI, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil) hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil)
// -> e // -> e
msg, _, _ := hsI.WriteMessage(nil, nil) msg, _, _ := hsI.WriteMessage(nil, nil)
@ -256,3 +256,80 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
c.Assert(err, IsNil) c.Assert(err, IsNil)
c.Assert(string(res), Equals, "bar") c.Assert(string(res), Equals, "bar")
} }
func (NoiseSuite) TestPSK_X(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashSHA256)
rng := new(RandomInc)
staticI := cs.GenerateKeypair(rng)
staticR := cs.GenerateKeypair(rng)
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, []byte{0x01, 0x02, 0x03}, &staticI, nil, staticR.Public, nil)
msg, _, _ := hs.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 96)
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51a983a01a35059140decfb16a5748b5673a261e4bb69a11f0d698cf6d5117f99eadcacaa2082307089ab2c633970cdbe1da510833a29ba3211174d35780b58e99c")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestPSK_NN(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
prologue := []byte{0x01, 0x02, 0x03}
psk := []byte{0x04, 0x05, 0x06}
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, prologue, psk, nil, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, prologue, psk, nil, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 51)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 52)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c4f28d0b09ff91e2ff6bb55bb99bc74436056c0d1")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestPSK_XX(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
prologue := []byte{0x01, 0x02, 0x03}
psk := []byte{0x04, 0x05, 0x06}
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, prologue, psk, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, prologue, psk, &staticR, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 51)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 100)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
msg, _, _ = hsI.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 64)
res, _, _, err = hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
expected, _ := hex.DecodeString("eb8f3a6d5b68c7048cf61cbbff4a19959fed3ad315ef0d088f00681f3f38295d5d2aee59874e22cf9e86c2df3aaea03449435de887bab9bde1ee7ef392785fdf")
c.Assert(msg, DeepEquals, expected)
}

View File

@ -26,7 +26,8 @@ func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
type SymmetricState struct { type SymmetricState struct {
CipherState CipherState
hasKey bool hasK bool
hasPSK bool
ck []byte ck []byte
h []byte h []byte
} }
@ -46,7 +47,7 @@ func (s *SymmetricState) InitializeSymmetric(handshakeName []byte) {
func (s *SymmetricState) MixKey(dhOutput []byte) { func (s *SymmetricState) MixKey(dhOutput []byte) {
s.n = 0 s.n = 0
s.hasKey = true s.hasK = true
var hk []byte var hk []byte
s.ck, hk = HKDF(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput) s.ck, hk = HKDF(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput)
copy(s.k[:], hk) copy(s.k[:], hk)
@ -60,8 +61,15 @@ func (s *SymmetricState) MixHash(data []byte) {
s.h = h.Sum(s.h[:0]) s.h = h.Sum(s.h[:0])
} }
func (s *SymmetricState) MixPresharedKey(presharedKey []byte) {
var temp []byte
s.ck, temp = HKDF(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey)
s.MixHash(temp)
s.hasPSK = true
}
func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte { func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte {
if !s.hasKey { if !s.hasK {
s.MixHash(plaintext) s.MixHash(plaintext)
return append(out, plaintext...) return append(out, plaintext...)
} }
@ -71,7 +79,7 @@ func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte {
} }
func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
if !s.hasKey { if !s.hasK {
s.MixHash(data) s.MixHash(data)
return append(out, data...), nil return append(out, data...), nil
} }
@ -119,7 +127,6 @@ type HandshakeState struct {
e DHKey // local ephemeral keypair e DHKey // local ephemeral keypair
rs []byte // remote party's static public key rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key re []byte // remote party's ephemeral public key
psk bool
messagePatterns [][]MessagePattern messagePatterns [][]MessagePattern
shouldWrite bool shouldWrite bool
msgIdx int msgIdx int
@ -130,7 +137,6 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh
hs := &HandshakeState{ hs := &HandshakeState{
rs: newRS, rs: newRS,
re: newRE, re: newRE,
psk: len(presharedKey) > 0,
messagePatterns: newHandshakePattern.Messages, messagePatterns: newHandshakePattern.Messages,
shouldWrite: initiator, shouldWrite: initiator,
rng: rng, rng: rng,
@ -143,13 +149,13 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh
hs.s = *newS hs.s = *newS
} }
namePrefix := "Noise_" namePrefix := "Noise_"
if hs.psk { if hs.hasPSK {
namePrefix = "NoisePSK_" namePrefix = "NoisePSK_"
} }
hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name()))) hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name())))
hs.MixHash(prologue) hs.MixHash(prologue)
if hs.psk { if len(presharedKey) > 0 {
hs.MixHash(presharedKey) hs.MixPresharedKey(presharedKey)
} }
for _, m := range newHandshakePattern.InitiatorPreMessages { for _, m := range newHandshakePattern.InitiatorPreMessages {
switch { switch {
@ -195,7 +201,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
s.e = s.cs.GenerateKeypair(s.rng) s.e = s.cs.GenerateKeypair(s.rng)
out = append(out, s.e.Public...) out = append(out, s.e.Public...)
s.MixHash(s.e.Public) s.MixHash(s.e.Public)
if s.psk { if s.hasPSK {
s.MixKey(s.e.Public) s.MixKey(s.e.Public)
} }
case MessagePatternS: case MessagePatternS:
@ -240,7 +246,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
switch msg { switch msg {
case MessagePatternE, MessagePatternS: case MessagePatternE, MessagePatternS:
expected := s.cs.DHLen() expected := s.cs.DHLen()
if msg == MessagePatternS && s.hasKey { if msg == MessagePatternS && s.hasK {
expected += 16 expected += 16
} }
if len(message) < expected { if len(message) < expected {
@ -254,7 +260,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.re = s.re[:s.cs.DHLen()] s.re = s.re[:s.cs.DHLen()]
copy(s.re, message) copy(s.re, message)
s.MixHash(s.re) s.MixHash(s.re)
if s.psk { if s.hasPSK {
s.MixKey(s.re) s.MixKey(s.re)
} }
case MessagePatternS: case MessagePatternS: