psk mode: update for redesign

The PSK mode has been redesigned in the latest revision of Noise, which
WireGuard is using. This patch updates the library to use this new
construction. It adds a outputs parameter to HKDF, a
PresharedKeyPlacement config parameter, as well as a PSK token. This has
been tested against the latest WireGuard git master, and the two are
compatible.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2017-05-11 15:56:03 +02:00
parent 6902797927
commit 6a0d1cad82
2 changed files with 59 additions and 19 deletions

25
hkdf.go
View File

@ -5,13 +5,19 @@ import (
"hash" "hash"
) )
func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) ([]byte, []byte) { func hkdf(h func() hash.Hash, outputs int, out1, out2, out3, chainingKey, inputKeyMaterial []byte) ([]byte, []byte, []byte) {
if len(out1) > 0 { if len(out1) > 0 {
panic("len(out1) > 0") panic("len(out1) > 0")
} }
if len(out2) > 0 { if len(out2) > 0 {
panic("len(out2) > 0") panic("len(out2) > 0")
} }
if len(out3) > 0 {
panic("len(out3) > 0")
}
if outputs > 3 {
panic("outputs > 3")
}
tempMAC := hmac.New(h, chainingKey) tempMAC := hmac.New(h, chainingKey)
tempMAC.Write(inputKeyMaterial) tempMAC.Write(inputKeyMaterial)
@ -21,10 +27,23 @@ func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte)
out1MAC.Write([]byte{0x01}) out1MAC.Write([]byte{0x01})
out1 = out1MAC.Sum(out1) out1 = out1MAC.Sum(out1)
if outputs == 1 {
return out1, nil, nil
}
out2MAC := hmac.New(h, tempKey) out2MAC := hmac.New(h, tempKey)
out2MAC.Write(out1) out2MAC.Write(out1)
out2MAC.Write([]byte{0x02}) out2MAC.Write([]byte{0x02})
out2 = out2MAC.Sum(tempKey[:0]) out2 = out2MAC.Sum(out2)
return out1, out2 if outputs == 2 {
return out1, out2, nil
}
out3MAC := hmac.New(h, tempKey)
out3MAC.Write(out2)
out3MAC.Write([]byte{0x03})
out3 = out3MAC.Sum(out3)
return out1, out2, out3
} }

View File

@ -9,6 +9,7 @@ package noise
import ( import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt"
"io" "io"
) )
@ -63,7 +64,6 @@ func (s *CipherState) Cipher() Cipher {
type symmetricState struct { type symmetricState struct {
CipherState CipherState
hasK bool hasK bool
hasPSK bool
ck []byte ck []byte
h []byte h []byte
@ -88,7 +88,7 @@ func (s *symmetricState) MixKey(dhOutput []byte) {
s.n = 0 s.n = 0
s.hasK = true s.hasK = true
var hk []byte var hk []byte
s.ck, hk = hkdf(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput) s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
copy(s.k[:], hk) copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k) s.c = s.cs.Cipher(s.k)
} }
@ -100,11 +100,15 @@ func (s *symmetricState) MixHash(data []byte) {
s.h = h.Sum(s.h[:0]) s.h = h.Sum(s.h[:0])
} }
func (s *symmetricState) MixPresharedKey(presharedKey []byte) { func (s *symmetricState) MixKeyAndHash(data []byte) {
var hk []byte
var temp []byte var temp []byte
s.ck, temp = hkdf(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey) s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
s.MixHash(temp) s.MixHash(temp)
s.hasPSK = true copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k)
s.n = 0
s.hasK = true
} }
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte { func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte {
@ -132,7 +136,7 @@ func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
func (s *symmetricState) Split() (*CipherState, *CipherState) { func (s *symmetricState) Split() (*CipherState, *CipherState) {
s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs} s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
hk1, hk2 := hkdf(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil) hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
copy(s1.k[:], hk1) copy(s1.k[:], hk1)
copy(s2.k[:], hk2) copy(s2.k[:], hk2)
s1.c = s.cs.Cipher(s1.k) s1.c = s.cs.Cipher(s1.k)
@ -180,6 +184,7 @@ const (
MessagePatternDHES MessagePatternDHES
MessagePatternDHSE MessagePatternDHSE
MessagePatternDHSS MessagePatternDHSS
MessagePatternPSK
) )
// MaxMsgLen is the maximum number of bytes that can be sent in a single Noise // MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
@ -194,6 +199,7 @@ type HandshakeState struct {
e DHKey // local ephemeral keypair e DHKey // local ephemeral keypair
rs []byte // remote party's static public key rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key re []byte // remote party's ephemeral public key
psk []byte // preshared key, maybe zero length
messagePatterns [][]MessagePattern messagePatterns [][]MessagePattern
shouldWrite bool shouldWrite bool
msgIdx int msgIdx int
@ -221,9 +227,13 @@ type Config struct {
// be identical on both sides for the handshake to succeed. // be identical on both sides for the handshake to succeed.
Prologue []byte Prologue []byte
// PresharedKey is the optional pre-shared key for the handshake. // PresharedKey is the optional preshared key for the handshake.
PresharedKey []byte PresharedKey []byte
// PresharedKeyPlacement specifies the placement position of the PSK token
// when PresharedKey is specified
PresharedKeyPlacement int
// StaticKeypair is this peer's static keypair, required if part of the // StaticKeypair is this peer's static keypair, required if part of the
// handshake. // handshake.
StaticKeypair DHKey StaticKeypair DHKey
@ -247,6 +257,7 @@ func NewHandshakeState(c Config) *HandshakeState {
s: c.StaticKeypair, s: c.StaticKeypair,
e: c.EphemeralKeypair, e: c.EphemeralKeypair,
rs: c.PeerStatic, rs: c.PeerStatic,
psk: c.PresharedKey,
messagePatterns: c.Pattern.Messages, messagePatterns: c.Pattern.Messages,
shouldWrite: c.Initiator, shouldWrite: c.Initiator,
rng: c.Random, rng: c.Random,
@ -259,15 +270,21 @@ func NewHandshakeState(c Config) *HandshakeState {
copy(hs.re, c.PeerEphemeral) copy(hs.re, c.PeerEphemeral)
} }
hs.ss.cs = c.CipherSuite hs.ss.cs = c.CipherSuite
namePrefix := "Noise_" pskModifier := ""
if len(c.PresharedKey) > 0 { if len(hs.psk) > 0 {
namePrefix = "NoisePSK_" if len(hs.psk) != 32 {
panic("noise: specification mandates 256-bit preshared keys")
} }
hs.ss.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.ss.cs.Name()))) pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
if (c.PresharedKeyPlacement == 0) {
hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
} else {
hs.messagePatterns[c.PresharedKeyPlacement - 1] = append(hs.messagePatterns[c.PresharedKeyPlacement - 1], MessagePatternPSK)
}
}
hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
hs.ss.MixHash(c.Prologue) hs.ss.MixHash(c.Prologue)
if len(c.PresharedKey) > 0 {
hs.ss.MixPresharedKey(c.PresharedKey)
}
for _, m := range c.Pattern.InitiatorPreMessages { for _, m := range c.Pattern.InitiatorPreMessages {
switch { switch {
case c.Initiator && m == MessagePatternS: case c.Initiator && m == MessagePatternS:
@ -318,7 +335,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
s.e = s.ss.cs.GenerateKeypair(s.rng) s.e = s.ss.cs.GenerateKeypair(s.rng)
out = append(out, s.e.Public...) out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public) s.ss.MixHash(s.e.Public)
if s.ss.hasPSK { if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public) s.ss.MixKey(s.e.Public)
} }
case MessagePatternS: case MessagePatternS:
@ -334,6 +351,8 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re))
case MessagePatternDHSS: case MessagePatternDHSS:
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
} }
} }
s.shouldWrite = false s.shouldWrite = false
@ -385,7 +404,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.re = s.re[:s.ss.cs.DHLen()] s.re = s.re[:s.ss.cs.DHLen()]
copy(s.re, message) copy(s.re, message)
s.ss.MixHash(s.re) s.ss.MixHash(s.re)
if s.ss.hasPSK { if len(s.psk) > 0 {
s.ss.MixKey(s.re) s.ss.MixKey(s.re)
} }
case MessagePatternS: case MessagePatternS:
@ -407,6 +426,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs))
case MessagePatternDHSS: case MessagePatternDHSS:
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
} }
} }
out, err = s.ss.DecryptAndHash(out, message) out, err = s.ss.DecryptAndHash(out, message)