diff --git a/noise_test.go b/noise_test.go index 9634cfa..2839d33 100644 --- a/noise_test.go +++ b/noise_test.go @@ -27,7 +27,7 @@ func (NoiseSuite) TestN(c *C) { cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) rng := new(RandomInc) staticR := cs.GenerateKeypair(rng) - hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, staticR.Public, nil) + hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, nil, staticR.Public, nil) hello, _, _ := hs.WriteMessage(nil, nil) expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4") @@ -39,7 +39,7 @@ func (NoiseSuite) TestX(c *C) { rng := new(RandomInc) staticI := cs.GenerateKeypair(rng) staticR := cs.GenerateKeypair(rng) - hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, &staticI, nil, staticR.Public, nil) + hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, nil, &staticI, nil, staticR.Public, nil) hello, _, _ := hs.WriteMessage(nil, nil) expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d") @@ -52,8 +52,8 @@ func (NoiseSuite) TestNN(c *C) { rngR := new(RandomInc) *rngR = 1 - hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil) + hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil, nil) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 35) @@ -80,8 +80,8 @@ func (NoiseSuite) TestXX(c *C) { staticI := cs.GenerateKeypair(rngI) staticR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil) + hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 35) @@ -114,8 +114,8 @@ func (NoiseSuite) TestIK(c *C) { staticI := cs.GenerateKeypair(rngI) staticR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), &staticI, nil, staticR.Public, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), &staticR, nil, nil, nil) + hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), nil, &staticI, nil, staticR.Public, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), nil, &staticR, nil, nil, nil) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 99) @@ -143,8 +143,8 @@ func (NoiseSuite) TestXE(c *C) { staticR := cs.GenerateKeypair(rngR) ephR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, &staticI, nil, staticR.Public, ephR.Public) - hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, &staticR, &ephR, nil, nil) + hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, nil, &staticI, nil, staticR.Public, ephR.Public) + hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, nil, &staticR, &ephR, nil, nil) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 51) @@ -177,8 +177,8 @@ func (NoiseSuite) TestXXRoundtrip(c *C) { staticI := cs.GenerateKeypair(rngI) staticR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil) + hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil) // -> e msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef")) @@ -220,3 +220,39 @@ func (NoiseSuite) TestXXRoundtrip(c *C) { c.Assert(err, IsNil) c.Assert(string(res), Equals, "worri") } + +func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) { + cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil) + hsR := NewHandshakeState(cs, rngI, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil) + + // -> e + msg, _, _ := hsI.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 48) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + // <- e, dhee + msg, csR0, csR1 := hsR.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 48) + res, csI0, csI1, err := hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + // transport I -> R + msg = csI0.Encrypt(nil, nil, []byte("foo")) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "foo") + + // transport R -> I + msg = csR1.Encrypt(nil, nil, []byte("bar")) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "bar") +} diff --git a/patterns.go b/patterns.go index b9d8144..c17543d 100644 --- a/patterns.go +++ b/patterns.go @@ -84,7 +84,7 @@ var HandshakeXN = HandshakePattern{ var HandshakeIN = HandshakePattern{ Name: "IN", Messages: [][]MessagePattern{ - {MessagePatternS, MessagePatternE}, + {MessagePatternE, MessagePatternS}, {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, }, } @@ -139,7 +139,7 @@ var HandshakeXX = HandshakePattern{ var HandshakeIX = HandshakePattern{ Name: "IX", Messages: [][]MessagePattern{ - {MessagePatternS, MessagePatternE}, + {MessagePatternE, MessagePatternS}, {MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE}, }, } diff --git a/state.go b/state.go index f5aeec6..ede2c54 100644 --- a/state.go +++ b/state.go @@ -119,16 +119,18 @@ type HandshakeState struct { e DHKey // local ephemeral keypair rs []byte // remote party's static public key re []byte // remote party's ephemeral public key + psk bool messagePatterns [][]MessagePattern shouldWrite bool msgIdx int rng io.Reader } -func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState { +func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue, presharedKey []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState { hs := &HandshakeState{ rs: newRS, re: newRE, + psk: len(presharedKey) > 0, messagePatterns: newHandshakePattern.Messages, shouldWrite: initiator, rng: rng, @@ -140,8 +142,15 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh if newS != nil { hs.s = *newS } - hs.InitializeSymmetric([]byte("Noise_" + newHandshakePattern.Name + "_" + string(cs.Name()))) + namePrefix := "Noise_" + if hs.psk { + namePrefix = "NoisePSK_" + } + hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name()))) hs.MixHash(prologue) + if hs.psk { + hs.MixHash(presharedKey) + } for _, m := range newHandshakePattern.InitiatorPreMessages { switch { case initiator && m == MessagePatternS: @@ -184,7 +193,11 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState switch msg { case MessagePatternE: s.e = s.cs.GenerateKeypair(s.rng) - out = s.EncryptAndHash(out, s.e.Public) + out = append(out, s.e.Public...) + s.MixHash(s.e.Public) + if s.psk { + s.MixKey(s.e.Public) + } case MessagePatternS: if len(s.s.Public) == 0 { panic("noise: invalid state, s.Public is nil") @@ -227,7 +240,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, switch msg { case MessagePatternE, MessagePatternS: expected := s.cs.DHLen() - if s.hasKey { + if msg == MessagePatternS && s.hasKey { expected += 16 } if len(message) < expected { @@ -235,7 +248,15 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, } switch msg { case MessagePatternE: - s.re, err = s.DecryptAndHash(s.re[:0], message[:expected]) + if cap(s.re) < s.cs.DHLen() { + s.re = make([]byte, s.cs.DHLen()) + } + s.re = s.re[:s.cs.DHLen()] + copy(s.re, message) + s.MixHash(s.re) + if s.psk { + s.MixKey(s.re) + } case MessagePatternS: if len(s.rs) > 0 { panic("noise: invalid state, rs is not nil")