From b11a33c4aeec0f9a7e721b9665286bbc49ee0056 Mon Sep 17 00:00:00 2001 From: Jonathan Rudenberg Date: Mon, 16 Nov 2015 13:09:47 -0500 Subject: [PATCH] Add handshake Config struct Signed-off-by: Jonathan Rudenberg --- noise_test.go | 38 ++++++++++++------------ state.go | 81 ++++++++++++++++++++++++++++----------------------- 2 files changed, 64 insertions(+), 55 deletions(-) diff --git a/noise_test.go b/noise_test.go index e6a9bd4..9807ddb 100644 --- a/noise_test.go +++ b/noise_test.go @@ -27,7 +27,7 @@ func (NoiseSuite) TestN(c *C) { cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) rng := new(RandomInc) staticR := cs.GenerateKeypair(rng) - hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, nil, staticR.Public, nil) + hs := NewHandshakeState(Config{CipherSuite: cs, Random: rng, Pattern: HandshakeN, Initiator: true, PeerStatic: staticR.Public}) hello, _, _ := hs.WriteMessage(nil, nil) expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4") @@ -39,7 +39,7 @@ func (NoiseSuite) TestX(c *C) { rng := new(RandomInc) staticI := cs.GenerateKeypair(rng) staticR := cs.GenerateKeypair(rng) - hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, nil, &staticI, nil, staticR.Public, nil) + hs := NewHandshakeState(Config{CipherSuite: cs, Random: rng, Pattern: HandshakeX, Initiator: true, StaticKeypair: staticI, PeerStatic: staticR.Public}) hello, _, _ := hs.WriteMessage(nil, nil) expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d") @@ -52,8 +52,8 @@ func (NoiseSuite) TestNN(c *C) { rngR := new(RandomInc) *rngR = 1 - hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeNN, Initiator: true}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeNN, Initiator: false}) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 35) @@ -80,8 +80,8 @@ func (NoiseSuite) TestXX(c *C) { staticI := cs.GenerateKeypair(rngI) staticR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeXX, Initiator: true, StaticKeypair: staticI}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeXX, StaticKeypair: staticR}) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 35) @@ -114,8 +114,8 @@ func (NoiseSuite) TestIK(c *C) { staticI := cs.GenerateKeypair(rngI) staticR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), nil, &staticI, nil, staticR.Public, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), nil, &staticR, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeIK, Initiator: true, Prologue: []byte("ABC"), StaticKeypair: staticI, PeerStatic: staticR.Public}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeIK, Prologue: []byte("ABC"), StaticKeypair: staticR}) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 99) @@ -143,8 +143,8 @@ func (NoiseSuite) TestXE(c *C) { staticR := cs.GenerateKeypair(rngR) ephR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, nil, &staticI, nil, staticR.Public, ephR.Public) - hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, nil, &staticR, &ephR, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeXE, Initiator: true, StaticKeypair: staticI, PeerStatic: staticR.Public, PeerEphemeral: ephR.Public}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeXE, StaticKeypair: staticR, EphemeralKeypair: ephR}) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 51) @@ -177,8 +177,8 @@ func (NoiseSuite) TestXXRoundtrip(c *C) { staticI := cs.GenerateKeypair(rngI) staticR := cs.GenerateKeypair(rngR) - hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeXX, Initiator: true, StaticKeypair: staticI}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeXX, StaticKeypair: staticR}) // -> e msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef")) @@ -227,8 +227,8 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) { rngR := new(RandomInc) *rngR = 1 - hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeNN, Initiator: true, PresharedKey: []byte("supersecret")}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeNN, PresharedKey: []byte("supersecret")}) // -> e msg, _, _ := hsI.WriteMessage(nil, nil) @@ -263,7 +263,7 @@ func (NoiseSuite) TestPSK_X(c *C) { staticI := cs.GenerateKeypair(rng) staticR := cs.GenerateKeypair(rng) - hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, []byte{0x01, 0x02, 0x03}, &staticI, nil, staticR.Public, nil) + hs := NewHandshakeState(Config{CipherSuite: cs, Random: rng, Pattern: HandshakeX, Initiator: true, PresharedKey: []byte{0x01, 0x02, 0x03}, StaticKeypair: staticI, PeerStatic: staticR.Public}) msg, _, _ := hs.WriteMessage(nil, nil) c.Assert(msg, HasLen, 96) @@ -279,8 +279,8 @@ func (NoiseSuite) TestPSK_NN(c *C) { prologue := []byte{0x01, 0x02, 0x03} psk := []byte{0x04, 0x05, 0x06} - hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, prologue, psk, nil, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, prologue, psk, nil, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeNN, Initiator: true, Prologue: prologue, PresharedKey: psk}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeNN, Prologue: prologue, PresharedKey: psk}) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 51) @@ -309,8 +309,8 @@ func (NoiseSuite) TestPSK_XX(c *C) { prologue := []byte{0x01, 0x02, 0x03} psk := []byte{0x04, 0x05, 0x06} - hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, prologue, psk, &staticI, nil, nil, nil) - hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, prologue, psk, &staticR, nil, nil, nil) + hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeXX, Initiator: true, Prologue: prologue, PresharedKey: psk, StaticKeypair: staticI}) + hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeXX, Prologue: prologue, PresharedKey: psk, StaticKeypair: staticR}) msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) c.Assert(msg, HasLen, 51) diff --git a/state.go b/state.go index 6431eaa..d8914c3 100644 --- a/state.go +++ b/state.go @@ -133,52 +133,61 @@ type HandshakeState struct { rng io.Reader } -func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue, presharedKey []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState { +type Config struct { + CipherSuite CipherSuite + Random io.Reader + Pattern HandshakePattern + Initiator bool + Prologue []byte + PresharedKey []byte + StaticKeypair DHKey + EphemeralKeypair DHKey + PeerStatic []byte + PeerEphemeral []byte +} + +func NewHandshakeState(c Config) *HandshakeState { hs := &HandshakeState{ - rs: newRS, - re: newRE, - messagePatterns: newHandshakePattern.Messages, - shouldWrite: initiator, - rng: rng, - } - hs.SymmetricState.cs = cs - if newE != nil { - hs.e = *newE - } - if newS != nil { - hs.s = *newS + s: c.StaticKeypair, + e: c.EphemeralKeypair, + rs: c.PeerStatic, + re: c.PeerEphemeral, + messagePatterns: c.Pattern.Messages, + shouldWrite: c.Initiator, + rng: c.Random, } + hs.SymmetricState.cs = c.CipherSuite namePrefix := "Noise_" - if hs.hasPSK { + if len(c.PresharedKey) > 0 { namePrefix = "NoisePSK_" } - hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name()))) - hs.MixHash(prologue) - if len(presharedKey) > 0 { - hs.MixPresharedKey(presharedKey) + hs.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.cs.Name()))) + hs.MixHash(c.Prologue) + if len(c.PresharedKey) > 0 { + hs.MixPresharedKey(c.PresharedKey) } - for _, m := range newHandshakePattern.InitiatorPreMessages { + for _, m := range c.Pattern.InitiatorPreMessages { switch { - case initiator && m == MessagePatternS: - hs.MixHash(newS.Public) - case initiator && m == MessagePatternE: - hs.MixHash(newE.Public) - case !initiator && m == MessagePatternS: - hs.MixHash(newRS) - case !initiator && m == MessagePatternE: - hs.MixHash(newRE) + case c.Initiator && m == MessagePatternS: + hs.MixHash(hs.s.Public) + case c.Initiator && m == MessagePatternE: + hs.MixHash(hs.e.Public) + case !c.Initiator && m == MessagePatternS: + hs.MixHash(hs.rs) + case !c.Initiator && m == MessagePatternE: + hs.MixHash(hs.re) } } - for _, m := range newHandshakePattern.ResponderPreMessages { + for _, m := range c.Pattern.ResponderPreMessages { switch { - case !initiator && m == MessagePatternS: - hs.MixHash(newS.Public) - case !initiator && m == MessagePatternE: - hs.MixHash(newE.Public) - case initiator && m == MessagePatternS: - hs.MixHash(newRS) - case initiator && m == MessagePatternE: - hs.MixHash(newRE) + case !c.Initiator && m == MessagePatternS: + hs.MixHash(hs.s.Public) + case !c.Initiator && m == MessagePatternE: + hs.MixHash(hs.e.Public) + case c.Initiator && m == MessagePatternS: + hs.MixHash(hs.rs) + case c.Initiator && m == MessagePatternE: + hs.MixHash(hs.re) } } return hs