mirror of https://github.com/waku-org/noise.git
Implement NoisePSK revision 2
This commit is contained in:
parent
f7b9b28336
commit
0e9c47ad19
|
@ -27,7 +27,7 @@ func (NoiseSuite) TestN(c *C) {
|
||||||
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
|
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
|
||||||
rng := new(RandomInc)
|
rng := new(RandomInc)
|
||||||
staticR := cs.GenerateKeypair(rng)
|
staticR := cs.GenerateKeypair(rng)
|
||||||
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, staticR.Public, nil)
|
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, nil, staticR.Public, nil)
|
||||||
|
|
||||||
hello, _, _ := hs.WriteMessage(nil, nil)
|
hello, _, _ := hs.WriteMessage(nil, nil)
|
||||||
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4")
|
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4")
|
||||||
|
@ -39,7 +39,7 @@ func (NoiseSuite) TestX(c *C) {
|
||||||
rng := new(RandomInc)
|
rng := new(RandomInc)
|
||||||
staticI := cs.GenerateKeypair(rng)
|
staticI := cs.GenerateKeypair(rng)
|
||||||
staticR := cs.GenerateKeypair(rng)
|
staticR := cs.GenerateKeypair(rng)
|
||||||
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, &staticI, nil, staticR.Public, nil)
|
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, nil, &staticI, nil, staticR.Public, nil)
|
||||||
|
|
||||||
hello, _, _ := hs.WriteMessage(nil, nil)
|
hello, _, _ := hs.WriteMessage(nil, nil)
|
||||||
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d")
|
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d")
|
||||||
|
@ -52,8 +52,8 @@ func (NoiseSuite) TestNN(c *C) {
|
||||||
rngR := new(RandomInc)
|
rngR := new(RandomInc)
|
||||||
*rngR = 1
|
*rngR = 1
|
||||||
|
|
||||||
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil)
|
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil, nil)
|
||||||
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil)
|
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||||
c.Assert(msg, HasLen, 35)
|
c.Assert(msg, HasLen, 35)
|
||||||
|
@ -80,8 +80,8 @@ func (NoiseSuite) TestXX(c *C) {
|
||||||
staticI := cs.GenerateKeypair(rngI)
|
staticI := cs.GenerateKeypair(rngI)
|
||||||
staticR := cs.GenerateKeypair(rngR)
|
staticR := cs.GenerateKeypair(rngR)
|
||||||
|
|
||||||
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
|
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil)
|
||||||
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
|
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil)
|
||||||
|
|
||||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||||
c.Assert(msg, HasLen, 35)
|
c.Assert(msg, HasLen, 35)
|
||||||
|
@ -114,8 +114,8 @@ func (NoiseSuite) TestIK(c *C) {
|
||||||
staticI := cs.GenerateKeypair(rngI)
|
staticI := cs.GenerateKeypair(rngI)
|
||||||
staticR := cs.GenerateKeypair(rngR)
|
staticR := cs.GenerateKeypair(rngR)
|
||||||
|
|
||||||
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), &staticI, nil, staticR.Public, nil)
|
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), nil, &staticI, nil, staticR.Public, nil)
|
||||||
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), &staticR, nil, nil, nil)
|
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), nil, &staticR, nil, nil, nil)
|
||||||
|
|
||||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||||
c.Assert(msg, HasLen, 99)
|
c.Assert(msg, HasLen, 99)
|
||||||
|
@ -143,8 +143,8 @@ func (NoiseSuite) TestXE(c *C) {
|
||||||
staticR := cs.GenerateKeypair(rngR)
|
staticR := cs.GenerateKeypair(rngR)
|
||||||
ephR := cs.GenerateKeypair(rngR)
|
ephR := cs.GenerateKeypair(rngR)
|
||||||
|
|
||||||
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, &staticI, nil, staticR.Public, ephR.Public)
|
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, nil, &staticI, nil, staticR.Public, ephR.Public)
|
||||||
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, &staticR, &ephR, nil, nil)
|
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, nil, &staticR, &ephR, nil, nil)
|
||||||
|
|
||||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
|
||||||
c.Assert(msg, HasLen, 51)
|
c.Assert(msg, HasLen, 51)
|
||||||
|
@ -177,8 +177,8 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
|
||||||
staticI := cs.GenerateKeypair(rngI)
|
staticI := cs.GenerateKeypair(rngI)
|
||||||
staticR := cs.GenerateKeypair(rngR)
|
staticR := cs.GenerateKeypair(rngR)
|
||||||
|
|
||||||
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
|
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil)
|
||||||
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
|
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil)
|
||||||
|
|
||||||
// -> e
|
// -> e
|
||||||
msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef"))
|
msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef"))
|
||||||
|
@ -220,3 +220,39 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
|
||||||
c.Assert(err, IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(string(res), Equals, "worri")
|
c.Assert(string(res), Equals, "worri")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
|
||||||
|
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b)
|
||||||
|
rngI := new(RandomInc)
|
||||||
|
rngR := new(RandomInc)
|
||||||
|
*rngR = 1
|
||||||
|
|
||||||
|
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil)
|
||||||
|
hsR := NewHandshakeState(cs, rngI, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil)
|
||||||
|
|
||||||
|
// -> e
|
||||||
|
msg, _, _ := hsI.WriteMessage(nil, nil)
|
||||||
|
c.Assert(msg, HasLen, 48)
|
||||||
|
res, _, _, err := hsR.ReadMessage(nil, msg)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
c.Assert(res, HasLen, 0)
|
||||||
|
|
||||||
|
// <- e, dhee
|
||||||
|
msg, csR0, csR1 := hsR.WriteMessage(nil, nil)
|
||||||
|
c.Assert(msg, HasLen, 48)
|
||||||
|
res, csI0, csI1, err := hsI.ReadMessage(nil, msg)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
c.Assert(res, HasLen, 0)
|
||||||
|
|
||||||
|
// transport I -> R
|
||||||
|
msg = csI0.Encrypt(nil, nil, []byte("foo"))
|
||||||
|
res, err = csR0.Decrypt(nil, nil, msg)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
c.Assert(string(res), Equals, "foo")
|
||||||
|
|
||||||
|
// transport R -> I
|
||||||
|
msg = csR1.Encrypt(nil, nil, []byte("bar"))
|
||||||
|
res, err = csI1.Decrypt(nil, nil, msg)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
c.Assert(string(res), Equals, "bar")
|
||||||
|
}
|
||||||
|
|
|
@ -84,7 +84,7 @@ var HandshakeXN = HandshakePattern{
|
||||||
var HandshakeIN = HandshakePattern{
|
var HandshakeIN = HandshakePattern{
|
||||||
Name: "IN",
|
Name: "IN",
|
||||||
Messages: [][]MessagePattern{
|
Messages: [][]MessagePattern{
|
||||||
{MessagePatternS, MessagePatternE},
|
{MessagePatternE, MessagePatternS},
|
||||||
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
|
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -139,7 +139,7 @@ var HandshakeXX = HandshakePattern{
|
||||||
var HandshakeIX = HandshakePattern{
|
var HandshakeIX = HandshakePattern{
|
||||||
Name: "IX",
|
Name: "IX",
|
||||||
Messages: [][]MessagePattern{
|
Messages: [][]MessagePattern{
|
||||||
{MessagePatternS, MessagePatternE},
|
{MessagePatternE, MessagePatternS},
|
||||||
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
|
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
31
state.go
31
state.go
|
@ -119,16 +119,18 @@ type HandshakeState struct {
|
||||||
e DHKey // local ephemeral keypair
|
e DHKey // local ephemeral keypair
|
||||||
rs []byte // remote party's static public key
|
rs []byte // remote party's static public key
|
||||||
re []byte // remote party's ephemeral public key
|
re []byte // remote party's ephemeral public key
|
||||||
|
psk bool
|
||||||
messagePatterns [][]MessagePattern
|
messagePatterns [][]MessagePattern
|
||||||
shouldWrite bool
|
shouldWrite bool
|
||||||
msgIdx int
|
msgIdx int
|
||||||
rng io.Reader
|
rng io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
|
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue, presharedKey []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
|
||||||
hs := &HandshakeState{
|
hs := &HandshakeState{
|
||||||
rs: newRS,
|
rs: newRS,
|
||||||
re: newRE,
|
re: newRE,
|
||||||
|
psk: len(presharedKey) > 0,
|
||||||
messagePatterns: newHandshakePattern.Messages,
|
messagePatterns: newHandshakePattern.Messages,
|
||||||
shouldWrite: initiator,
|
shouldWrite: initiator,
|
||||||
rng: rng,
|
rng: rng,
|
||||||
|
@ -140,8 +142,15 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh
|
||||||
if newS != nil {
|
if newS != nil {
|
||||||
hs.s = *newS
|
hs.s = *newS
|
||||||
}
|
}
|
||||||
hs.InitializeSymmetric([]byte("Noise_" + newHandshakePattern.Name + "_" + string(cs.Name())))
|
namePrefix := "Noise_"
|
||||||
|
if hs.psk {
|
||||||
|
namePrefix = "NoisePSK_"
|
||||||
|
}
|
||||||
|
hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name())))
|
||||||
hs.MixHash(prologue)
|
hs.MixHash(prologue)
|
||||||
|
if hs.psk {
|
||||||
|
hs.MixHash(presharedKey)
|
||||||
|
}
|
||||||
for _, m := range newHandshakePattern.InitiatorPreMessages {
|
for _, m := range newHandshakePattern.InitiatorPreMessages {
|
||||||
switch {
|
switch {
|
||||||
case initiator && m == MessagePatternS:
|
case initiator && m == MessagePatternS:
|
||||||
|
@ -184,7 +193,11 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
||||||
switch msg {
|
switch msg {
|
||||||
case MessagePatternE:
|
case MessagePatternE:
|
||||||
s.e = s.cs.GenerateKeypair(s.rng)
|
s.e = s.cs.GenerateKeypair(s.rng)
|
||||||
out = s.EncryptAndHash(out, s.e.Public)
|
out = append(out, s.e.Public...)
|
||||||
|
s.MixHash(s.e.Public)
|
||||||
|
if s.psk {
|
||||||
|
s.MixKey(s.e.Public)
|
||||||
|
}
|
||||||
case MessagePatternS:
|
case MessagePatternS:
|
||||||
if len(s.s.Public) == 0 {
|
if len(s.s.Public) == 0 {
|
||||||
panic("noise: invalid state, s.Public is nil")
|
panic("noise: invalid state, s.Public is nil")
|
||||||
|
@ -227,7 +240,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
||||||
switch msg {
|
switch msg {
|
||||||
case MessagePatternE, MessagePatternS:
|
case MessagePatternE, MessagePatternS:
|
||||||
expected := s.cs.DHLen()
|
expected := s.cs.DHLen()
|
||||||
if s.hasKey {
|
if msg == MessagePatternS && s.hasKey {
|
||||||
expected += 16
|
expected += 16
|
||||||
}
|
}
|
||||||
if len(message) < expected {
|
if len(message) < expected {
|
||||||
|
@ -235,7 +248,15 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
||||||
}
|
}
|
||||||
switch msg {
|
switch msg {
|
||||||
case MessagePatternE:
|
case MessagePatternE:
|
||||||
s.re, err = s.DecryptAndHash(s.re[:0], message[:expected])
|
if cap(s.re) < s.cs.DHLen() {
|
||||||
|
s.re = make([]byte, s.cs.DHLen())
|
||||||
|
}
|
||||||
|
s.re = s.re[:s.cs.DHLen()]
|
||||||
|
copy(s.re, message)
|
||||||
|
s.MixHash(s.re)
|
||||||
|
if s.psk {
|
||||||
|
s.MixKey(s.re)
|
||||||
|
}
|
||||||
case MessagePatternS:
|
case MessagePatternS:
|
||||||
if len(s.rs) > 0 {
|
if len(s.rs) > 0 {
|
||||||
panic("noise: invalid state, rs is not nil")
|
panic("noise: invalid state, rs is not nil")
|
||||||
|
|
Loading…
Reference in New Issue