mirror of https://github.com/waku-org/noise.git
Implement NoisePSK revision 2
This commit is contained in:
parent
f7b9b28336
commit
0e9c47ad19
|
@ -27,7 +27,7 @@ func (NoiseSuite) TestN(c *C) {
|
|||
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
|
||||
rng := new(RandomInc)
|
||||
staticR := cs.GenerateKeypair(rng)
|
||||
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, staticR.Public, nil)
|
||||
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, nil, staticR.Public, nil)
|
||||
|
||||
hello, _, _ := hs.WriteMessage(nil, nil)
|
||||
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4")
|
||||
|
@ -39,7 +39,7 @@ func (NoiseSuite) TestX(c *C) {
|
|||
rng := new(RandomInc)
|
||||
staticI := cs.GenerateKeypair(rng)
|
||||
staticR := cs.GenerateKeypair(rng)
|
||||
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, &staticI, nil, staticR.Public, nil)
|
||||
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, nil, &staticI, nil, staticR.Public, nil)
|
||||
|
||||
hello, _, _ := hs.WriteMessage(nil, nil)
|
||||
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d")
|
||||
|
@ -52,8 +52,8 @@ func (NoiseSuite) TestNN(c *C) {
|
|||
rngR := new(RandomInc)
|
||||
*rngR = 1
|
||||
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil)
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||
c.Assert(msg, HasLen, 35)
|
||||
|
@ -80,8 +80,8 @@ func (NoiseSuite) TestXX(c *C) {
|
|||
staticI := cs.GenerateKeypair(rngI)
|
||||
staticR := cs.GenerateKeypair(rngR)
|
||||
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil)
|
||||
|
||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||
c.Assert(msg, HasLen, 35)
|
||||
|
@ -114,8 +114,8 @@ func (NoiseSuite) TestIK(c *C) {
|
|||
staticI := cs.GenerateKeypair(rngI)
|
||||
staticR := cs.GenerateKeypair(rngR)
|
||||
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), &staticI, nil, staticR.Public, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), &staticR, nil, nil, nil)
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), nil, &staticI, nil, staticR.Public, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), nil, &staticR, nil, nil, nil)
|
||||
|
||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||
c.Assert(msg, HasLen, 99)
|
||||
|
@ -143,8 +143,8 @@ func (NoiseSuite) TestXE(c *C) {
|
|||
staticR := cs.GenerateKeypair(rngR)
|
||||
ephR := cs.GenerateKeypair(rngR)
|
||||
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, &staticI, nil, staticR.Public, ephR.Public)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, &staticR, &ephR, nil, nil)
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, nil, &staticI, nil, staticR.Public, ephR.Public)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, nil, &staticR, &ephR, nil, nil)
|
||||
|
||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||
c.Assert(msg, HasLen, 51)
|
||||
|
@ -177,8 +177,8 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
|
|||
staticI := cs.GenerateKeypair(rngI)
|
||||
staticR := cs.GenerateKeypair(rngR)
|
||||
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil)
|
||||
|
||||
// -> e
|
||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef"))
|
||||
|
@ -220,3 +220,39 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
c.Assert(string(res), Equals, "worri")
|
||||
}
|
||||
|
||||
func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
|
||||
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b)
|
||||
rngI := new(RandomInc)
|
||||
rngR := new(RandomInc)
|
||||
*rngR = 1
|
||||
|
||||
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil)
|
||||
hsR := NewHandshakeState(cs, rngI, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil)
|
||||
|
||||
// -> e
|
||||
msg, _, _ := hsI.WriteMessage(nil, nil)
|
||||
c.Assert(msg, HasLen, 48)
|
||||
res, _, _, err := hsR.ReadMessage(nil, msg)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(res, HasLen, 0)
|
||||
|
||||
// <- e, dhee
|
||||
msg, csR0, csR1 := hsR.WriteMessage(nil, nil)
|
||||
c.Assert(msg, HasLen, 48)
|
||||
res, csI0, csI1, err := hsI.ReadMessage(nil, msg)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(res, HasLen, 0)
|
||||
|
||||
// transport I -> R
|
||||
msg = csI0.Encrypt(nil, nil, []byte("foo"))
|
||||
res, err = csR0.Decrypt(nil, nil, msg)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(string(res), Equals, "foo")
|
||||
|
||||
// transport R -> I
|
||||
msg = csR1.Encrypt(nil, nil, []byte("bar"))
|
||||
res, err = csI1.Decrypt(nil, nil, msg)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(string(res), Equals, "bar")
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ var HandshakeXN = HandshakePattern{
|
|||
var HandshakeIN = HandshakePattern{
|
||||
Name: "IN",
|
||||
Messages: [][]MessagePattern{
|
||||
{MessagePatternS, MessagePatternE},
|
||||
{MessagePatternE, MessagePatternS},
|
||||
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
|
||||
},
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ var HandshakeXX = HandshakePattern{
|
|||
var HandshakeIX = HandshakePattern{
|
||||
Name: "IX",
|
||||
Messages: [][]MessagePattern{
|
||||
{MessagePatternS, MessagePatternE},
|
||||
{MessagePatternE, MessagePatternS},
|
||||
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
|
||||
},
|
||||
}
|
||||
|
|
31
state.go
31
state.go
|
@ -119,16 +119,18 @@ type HandshakeState struct {
|
|||
e DHKey // local ephemeral keypair
|
||||
rs []byte // remote party's static public key
|
||||
re []byte // remote party's ephemeral public key
|
||||
psk bool
|
||||
messagePatterns [][]MessagePattern
|
||||
shouldWrite bool
|
||||
msgIdx int
|
||||
rng io.Reader
|
||||
}
|
||||
|
||||
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
|
||||
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue, presharedKey []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
|
||||
hs := &HandshakeState{
|
||||
rs: newRS,
|
||||
re: newRE,
|
||||
psk: len(presharedKey) > 0,
|
||||
messagePatterns: newHandshakePattern.Messages,
|
||||
shouldWrite: initiator,
|
||||
rng: rng,
|
||||
|
@ -140,8 +142,15 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh
|
|||
if newS != nil {
|
||||
hs.s = *newS
|
||||
}
|
||||
hs.InitializeSymmetric([]byte("Noise_" + newHandshakePattern.Name + "_" + string(cs.Name())))
|
||||
namePrefix := "Noise_"
|
||||
if hs.psk {
|
||||
namePrefix = "NoisePSK_"
|
||||
}
|
||||
hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name())))
|
||||
hs.MixHash(prologue)
|
||||
if hs.psk {
|
||||
hs.MixHash(presharedKey)
|
||||
}
|
||||
for _, m := range newHandshakePattern.InitiatorPreMessages {
|
||||
switch {
|
||||
case initiator && m == MessagePatternS:
|
||||
|
@ -184,7 +193,11 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||
switch msg {
|
||||
case MessagePatternE:
|
||||
s.e = s.cs.GenerateKeypair(s.rng)
|
||||
out = s.EncryptAndHash(out, s.e.Public)
|
||||
out = append(out, s.e.Public...)
|
||||
s.MixHash(s.e.Public)
|
||||
if s.psk {
|
||||
s.MixKey(s.e.Public)
|
||||
}
|
||||
case MessagePatternS:
|
||||
if len(s.s.Public) == 0 {
|
||||
panic("noise: invalid state, s.Public is nil")
|
||||
|
@ -227,7 +240,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
switch msg {
|
||||
case MessagePatternE, MessagePatternS:
|
||||
expected := s.cs.DHLen()
|
||||
if s.hasKey {
|
||||
if msg == MessagePatternS && s.hasKey {
|
||||
expected += 16
|
||||
}
|
||||
if len(message) < expected {
|
||||
|
@ -235,7 +248,15 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
}
|
||||
switch msg {
|
||||
case MessagePatternE:
|
||||
s.re, err = s.DecryptAndHash(s.re[:0], message[:expected])
|
||||
if cap(s.re) < s.cs.DHLen() {
|
||||
s.re = make([]byte, s.cs.DHLen())
|
||||
}
|
||||
s.re = s.re[:s.cs.DHLen()]
|
||||
copy(s.re, message)
|
||||
s.MixHash(s.re)
|
||||
if s.psk {
|
||||
s.MixKey(s.re)
|
||||
}
|
||||
case MessagePatternS:
|
||||
if len(s.rs) > 0 {
|
||||
panic("noise: invalid state, rs is not nil")
|
||||
|
|
Loading…
Reference in New Issue