mirror of https://github.com/waku-org/noise.git
psk mode: update for redesign
The PSK mode has been redesigned in the latest revision of Noise, which WireGuard is using. This patch updates the library to use this new construction. It adds a outputs parameter to HKDF, a PresharedKeyPlacement config parameter, as well as a PSK token. This has been tested against the latest WireGuard git master, and the two are compatible. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
6902797927
commit
6a0d1cad82
25
hkdf.go
25
hkdf.go
|
@ -5,13 +5,19 @@ import (
|
|||
"hash"
|
||||
)
|
||||
|
||||
func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) ([]byte, []byte) {
|
||||
func hkdf(h func() hash.Hash, outputs int, out1, out2, out3, chainingKey, inputKeyMaterial []byte) ([]byte, []byte, []byte) {
|
||||
if len(out1) > 0 {
|
||||
panic("len(out1) > 0")
|
||||
}
|
||||
if len(out2) > 0 {
|
||||
panic("len(out2) > 0")
|
||||
}
|
||||
if len(out3) > 0 {
|
||||
panic("len(out3) > 0")
|
||||
}
|
||||
if outputs > 3 {
|
||||
panic("outputs > 3")
|
||||
}
|
||||
|
||||
tempMAC := hmac.New(h, chainingKey)
|
||||
tempMAC.Write(inputKeyMaterial)
|
||||
|
@ -21,10 +27,23 @@ func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte)
|
|||
out1MAC.Write([]byte{0x01})
|
||||
out1 = out1MAC.Sum(out1)
|
||||
|
||||
if outputs == 1 {
|
||||
return out1, nil, nil
|
||||
}
|
||||
|
||||
out2MAC := hmac.New(h, tempKey)
|
||||
out2MAC.Write(out1)
|
||||
out2MAC.Write([]byte{0x02})
|
||||
out2 = out2MAC.Sum(tempKey[:0])
|
||||
out2 = out2MAC.Sum(out2)
|
||||
|
||||
return out1, out2
|
||||
if outputs == 2 {
|
||||
return out1, out2, nil
|
||||
}
|
||||
|
||||
out3MAC := hmac.New(h, tempKey)
|
||||
out3MAC.Write(out2)
|
||||
out3MAC.Write([]byte{0x03})
|
||||
out3 = out3MAC.Sum(out3)
|
||||
|
||||
return out1, out2, out3
|
||||
}
|
||||
|
|
53
state.go
53
state.go
|
@ -9,6 +9,7 @@ package noise
|
|||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
@ -63,7 +64,6 @@ func (s *CipherState) Cipher() Cipher {
|
|||
type symmetricState struct {
|
||||
CipherState
|
||||
hasK bool
|
||||
hasPSK bool
|
||||
ck []byte
|
||||
h []byte
|
||||
|
||||
|
@ -88,7 +88,7 @@ func (s *symmetricState) MixKey(dhOutput []byte) {
|
|||
s.n = 0
|
||||
s.hasK = true
|
||||
var hk []byte
|
||||
s.ck, hk = hkdf(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput)
|
||||
s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
|
||||
copy(s.k[:], hk)
|
||||
s.c = s.cs.Cipher(s.k)
|
||||
}
|
||||
|
@ -100,11 +100,15 @@ func (s *symmetricState) MixHash(data []byte) {
|
|||
s.h = h.Sum(s.h[:0])
|
||||
}
|
||||
|
||||
func (s *symmetricState) MixPresharedKey(presharedKey []byte) {
|
||||
func (s *symmetricState) MixKeyAndHash(data []byte) {
|
||||
var hk []byte
|
||||
var temp []byte
|
||||
s.ck, temp = hkdf(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey)
|
||||
s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
|
||||
s.MixHash(temp)
|
||||
s.hasPSK = true
|
||||
copy(s.k[:], hk)
|
||||
s.c = s.cs.Cipher(s.k)
|
||||
s.n = 0
|
||||
s.hasK = true
|
||||
}
|
||||
|
||||
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte {
|
||||
|
@ -132,7 +136,7 @@ func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
|
|||
|
||||
func (s *symmetricState) Split() (*CipherState, *CipherState) {
|
||||
s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
|
||||
hk1, hk2 := hkdf(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil)
|
||||
hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
|
||||
copy(s1.k[:], hk1)
|
||||
copy(s2.k[:], hk2)
|
||||
s1.c = s.cs.Cipher(s1.k)
|
||||
|
@ -180,6 +184,7 @@ const (
|
|||
MessagePatternDHES
|
||||
MessagePatternDHSE
|
||||
MessagePatternDHSS
|
||||
MessagePatternPSK
|
||||
)
|
||||
|
||||
// MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
|
||||
|
@ -194,6 +199,7 @@ type HandshakeState struct {
|
|||
e DHKey // local ephemeral keypair
|
||||
rs []byte // remote party's static public key
|
||||
re []byte // remote party's ephemeral public key
|
||||
psk []byte // preshared key, maybe zero length
|
||||
messagePatterns [][]MessagePattern
|
||||
shouldWrite bool
|
||||
msgIdx int
|
||||
|
@ -221,9 +227,13 @@ type Config struct {
|
|||
// be identical on both sides for the handshake to succeed.
|
||||
Prologue []byte
|
||||
|
||||
// PresharedKey is the optional pre-shared key for the handshake.
|
||||
// PresharedKey is the optional preshared key for the handshake.
|
||||
PresharedKey []byte
|
||||
|
||||
// PresharedKeyPlacement specifies the placement position of the PSK token
|
||||
// when PresharedKey is specified
|
||||
PresharedKeyPlacement int
|
||||
|
||||
// StaticKeypair is this peer's static keypair, required if part of the
|
||||
// handshake.
|
||||
StaticKeypair DHKey
|
||||
|
@ -247,6 +257,7 @@ func NewHandshakeState(c Config) *HandshakeState {
|
|||
s: c.StaticKeypair,
|
||||
e: c.EphemeralKeypair,
|
||||
rs: c.PeerStatic,
|
||||
psk: c.PresharedKey,
|
||||
messagePatterns: c.Pattern.Messages,
|
||||
shouldWrite: c.Initiator,
|
||||
rng: c.Random,
|
||||
|
@ -259,15 +270,21 @@ func NewHandshakeState(c Config) *HandshakeState {
|
|||
copy(hs.re, c.PeerEphemeral)
|
||||
}
|
||||
hs.ss.cs = c.CipherSuite
|
||||
namePrefix := "Noise_"
|
||||
if len(c.PresharedKey) > 0 {
|
||||
namePrefix = "NoisePSK_"
|
||||
pskModifier := ""
|
||||
if len(hs.psk) > 0 {
|
||||
if len(hs.psk) != 32 {
|
||||
panic("noise: specification mandates 256-bit preshared keys")
|
||||
}
|
||||
pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
|
||||
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
|
||||
if (c.PresharedKeyPlacement == 0) {
|
||||
hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
|
||||
} else {
|
||||
hs.messagePatterns[c.PresharedKeyPlacement - 1] = append(hs.messagePatterns[c.PresharedKeyPlacement - 1], MessagePatternPSK)
|
||||
}
|
||||
}
|
||||
hs.ss.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.ss.cs.Name())))
|
||||
hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
|
||||
hs.ss.MixHash(c.Prologue)
|
||||
if len(c.PresharedKey) > 0 {
|
||||
hs.ss.MixPresharedKey(c.PresharedKey)
|
||||
}
|
||||
for _, m := range c.Pattern.InitiatorPreMessages {
|
||||
switch {
|
||||
case c.Initiator && m == MessagePatternS:
|
||||
|
@ -318,7 +335,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||
s.e = s.ss.cs.GenerateKeypair(s.rng)
|
||||
out = append(out, s.e.Public...)
|
||||
s.ss.MixHash(s.e.Public)
|
||||
if s.ss.hasPSK {
|
||||
if len(s.psk) > 0 {
|
||||
s.ss.MixKey(s.e.Public)
|
||||
}
|
||||
case MessagePatternS:
|
||||
|
@ -334,6 +351,8 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re))
|
||||
case MessagePatternDHSS:
|
||||
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
|
||||
case MessagePatternPSK:
|
||||
s.ss.MixKeyAndHash(s.psk)
|
||||
}
|
||||
}
|
||||
s.shouldWrite = false
|
||||
|
@ -385,7 +404,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
s.re = s.re[:s.ss.cs.DHLen()]
|
||||
copy(s.re, message)
|
||||
s.ss.MixHash(s.re)
|
||||
if s.ss.hasPSK {
|
||||
if len(s.psk) > 0 {
|
||||
s.ss.MixKey(s.re)
|
||||
}
|
||||
case MessagePatternS:
|
||||
|
@ -407,6 +426,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
|||
s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs))
|
||||
case MessagePatternDHSS:
|
||||
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
|
||||
case MessagePatternPSK:
|
||||
s.ss.MixKeyAndHash(s.psk)
|
||||
}
|
||||
}
|
||||
out, err = s.ss.DecryptAndHash(out, message)
|
||||
|
|
Loading…
Reference in New Issue