diff --git a/noise_test.go b/noise_test.go index 2839d33..e6a9bd4 100644 --- a/noise_test.go +++ b/noise_test.go @@ -124,12 +124,12 @@ func (NoiseSuite) TestIK(c *C) { c.Assert(string(res), Equals, "abc") msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) - c.Assert(msg, HasLen, 68) + c.Assert(msg, HasLen, 52) res, _, _, err = hsI.ReadMessage(nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "defg") - expected, _ := hex.DecodeString("5a491c3d8524aee516e7edccba51433ebe651002f0f79fd79dc6a4bf65ecd7b13543f1cc7910a367ffc3686f9c03e62e7555a9411133bb3194f27a9433507b30d858d578") + expected, _ := hex.DecodeString("5869aff450549732cbaaed5e5df9b30a6da31cb0e5742bad5ad4a1a768f1a67b7555a94199d0ce2972e0861b06c2152419a278de") c.Assert(msg, DeepEquals, expected) } @@ -153,7 +153,7 @@ func (NoiseSuite) TestXE(c *C) { c.Assert(string(res), Equals, "abc") msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) - c.Assert(msg, HasLen, 68) + c.Assert(msg, HasLen, 52) res, _, _, err = hsI.ReadMessage(nil, msg) c.Assert(err, IsNil) c.Assert(string(res), Equals, "defg") @@ -164,7 +164,7 @@ func (NoiseSuite) TestXE(c *C) { c.Assert(err, IsNil) c.Assert(res, HasLen, 0) - expected, _ := hex.DecodeString("08439f380b6f128a1465840d558f06abb1141cf5708a9dcf573d6e4fae01f90f7c9b8ef856bdc483df643a9d240ab6d38d9af9f3812ef44a465e32f8227a7c8b") + expected, _ := hex.DecodeString("08439f380b6f128a1465840d558f06abb1141cf5708a9dcf573d6e4fae01f90fd68dec89b26b249f2c4c61add5a1dbcf0a652ef015d7dbe0e80e9ea9af0aa7a2") c.Assert(msg, DeepEquals, expected) } @@ -228,7 +228,7 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) { *rngR = 1 hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil) - hsR := NewHandshakeState(cs, rngI, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil) // -> e msg, _, _ := hsI.WriteMessage(nil, nil) @@ -256,3 +256,80 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) { c.Assert(err, IsNil) c.Assert(string(res), Equals, "bar") } + +func (NoiseSuite) TestPSK_X(c *C) { + cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashSHA256) + rng := new(RandomInc) + staticI := cs.GenerateKeypair(rng) + staticR := cs.GenerateKeypair(rng) + + hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, []byte{0x01, 0x02, 0x03}, &staticI, nil, staticR.Public, nil) + msg, _, _ := hs.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 96) + + expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51a983a01a35059140decfb16a5748b5673a261e4bb69a11f0d698cf6d5117f99eadcacaa2082307089ab2c633970cdbe1da510833a29ba3211174d35780b58e99c") + c.Assert(msg, DeepEquals, expected) +} + +func (NoiseSuite) TestPSK_NN(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + prologue := []byte{0x01, 0x02, 0x03} + psk := []byte{0x04, 0x05, 0x06} + + hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, prologue, psk, nil, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, prologue, psk, nil, nil, nil, nil) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 51) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 52) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "defg") + + expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c4f28d0b09ff91e2ff6bb55bb99bc74436056c0d1") + c.Assert(msg, DeepEquals, expected) +} + +func (NoiseSuite) TestPSK_XX(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + staticI := cs.GenerateKeypair(rngI) + staticR := cs.GenerateKeypair(rngR) + prologue := []byte{0x01, 0x02, 0x03} + psk := []byte{0x04, 0x05, 0x06} + + hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, prologue, psk, &staticI, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, prologue, psk, &staticR, nil, nil, nil) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 51) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 100) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "defg") + + msg, _, _ = hsI.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 64) + res, _, _, err = hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + expected, _ := hex.DecodeString("eb8f3a6d5b68c7048cf61cbbff4a19959fed3ad315ef0d088f00681f3f38295d5d2aee59874e22cf9e86c2df3aaea03449435de887bab9bde1ee7ef392785fdf") + c.Assert(msg, DeepEquals, expected) +} diff --git a/state.go b/state.go index ede2c54..6431eaa 100644 --- a/state.go +++ b/state.go @@ -26,7 +26,8 @@ func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) { type SymmetricState struct { CipherState - hasKey bool + hasK bool + hasPSK bool ck []byte h []byte } @@ -46,7 +47,7 @@ func (s *SymmetricState) InitializeSymmetric(handshakeName []byte) { func (s *SymmetricState) MixKey(dhOutput []byte) { s.n = 0 - s.hasKey = true + s.hasK = true var hk []byte s.ck, hk = HKDF(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput) copy(s.k[:], hk) @@ -60,8 +61,15 @@ func (s *SymmetricState) MixHash(data []byte) { s.h = h.Sum(s.h[:0]) } +func (s *SymmetricState) MixPresharedKey(presharedKey []byte) { + var temp []byte + s.ck, temp = HKDF(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey) + s.MixHash(temp) + s.hasPSK = true +} + func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte { - if !s.hasKey { + if !s.hasK { s.MixHash(plaintext) return append(out, plaintext...) } @@ -71,7 +79,7 @@ func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte { } func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { - if !s.hasKey { + if !s.hasK { s.MixHash(data) return append(out, data...), nil } @@ -119,7 +127,6 @@ type HandshakeState struct { e DHKey // local ephemeral keypair rs []byte // remote party's static public key re []byte // remote party's ephemeral public key - psk bool messagePatterns [][]MessagePattern shouldWrite bool msgIdx int @@ -130,7 +137,6 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh hs := &HandshakeState{ rs: newRS, re: newRE, - psk: len(presharedKey) > 0, messagePatterns: newHandshakePattern.Messages, shouldWrite: initiator, rng: rng, @@ -143,13 +149,13 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh hs.s = *newS } namePrefix := "Noise_" - if hs.psk { + if hs.hasPSK { namePrefix = "NoisePSK_" } hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name()))) hs.MixHash(prologue) - if hs.psk { - hs.MixHash(presharedKey) + if len(presharedKey) > 0 { + hs.MixPresharedKey(presharedKey) } for _, m := range newHandshakePattern.InitiatorPreMessages { switch { @@ -195,7 +201,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState s.e = s.cs.GenerateKeypair(s.rng) out = append(out, s.e.Public...) s.MixHash(s.e.Public) - if s.psk { + if s.hasPSK { s.MixKey(s.e.Public) } case MessagePatternS: @@ -240,7 +246,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, switch msg { case MessagePatternE, MessagePatternS: expected := s.cs.DHLen() - if msg == MessagePatternS && s.hasKey { + if msg == MessagePatternS && s.hasK { expected += 16 } if len(message) < expected { @@ -254,7 +260,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, s.re = s.re[:s.cs.DHLen()] copy(s.re, message) s.MixHash(s.re) - if s.psk { + if s.hasPSK { s.MixKey(s.re) } case MessagePatternS: