diff --git a/hkdf.go b/hkdf.go index 94aa746..2ea494f 100644 --- a/hkdf.go +++ b/hkdf.go @@ -5,13 +5,19 @@ import ( "hash" ) -func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) ([]byte, []byte) { +func hkdf(h func() hash.Hash, outputs int, out1, out2, out3, chainingKey, inputKeyMaterial []byte) ([]byte, []byte, []byte) { if len(out1) > 0 { panic("len(out1) > 0") } if len(out2) > 0 { panic("len(out2) > 0") } + if len(out3) > 0 { + panic("len(out3) > 0") + } + if outputs > 3 { + panic("outputs > 3") + } tempMAC := hmac.New(h, chainingKey) tempMAC.Write(inputKeyMaterial) @@ -21,10 +27,23 @@ func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) out1MAC.Write([]byte{0x01}) out1 = out1MAC.Sum(out1) + if outputs == 1 { + return out1, nil, nil + } + out2MAC := hmac.New(h, tempKey) out2MAC.Write(out1) out2MAC.Write([]byte{0x02}) - out2 = out2MAC.Sum(tempKey[:0]) + out2 = out2MAC.Sum(out2) - return out1, out2 + if outputs == 2 { + return out1, out2, nil + } + + out3MAC := hmac.New(h, tempKey) + out3MAC.Write(out2) + out3MAC.Write([]byte{0x03}) + out3 = out3MAC.Sum(out3) + + return out1, out2, out3 } diff --git a/state.go b/state.go index dba6491..705d6af 100644 --- a/state.go +++ b/state.go @@ -9,6 +9,7 @@ package noise import ( "crypto/rand" "errors" + "fmt" "io" ) @@ -63,7 +64,6 @@ func (s *CipherState) Cipher() Cipher { type symmetricState struct { CipherState hasK bool - hasPSK bool ck []byte h []byte @@ -88,7 +88,7 @@ func (s *symmetricState) MixKey(dhOutput []byte) { s.n = 0 s.hasK = true var hk []byte - s.ck, hk = hkdf(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput) + s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput) copy(s.k[:], hk) s.c = s.cs.Cipher(s.k) } @@ -100,11 +100,15 @@ func (s *symmetricState) MixHash(data []byte) { s.h = h.Sum(s.h[:0]) } -func (s *symmetricState) MixPresharedKey(presharedKey []byte) { +func (s *symmetricState) MixKeyAndHash(data []byte) { + var hk []byte var temp []byte - s.ck, temp = hkdf(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey) + s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data) s.MixHash(temp) - s.hasPSK = true + copy(s.k[:], hk) + s.c = s.cs.Cipher(s.k) + s.n = 0 + s.hasK = true } func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte { @@ -132,7 +136,7 @@ func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { func (s *symmetricState) Split() (*CipherState, *CipherState) { s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs} - hk1, hk2 := hkdf(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil) + hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil) copy(s1.k[:], hk1) copy(s2.k[:], hk2) s1.c = s.cs.Cipher(s1.k) @@ -180,6 +184,7 @@ const ( MessagePatternDHES MessagePatternDHSE MessagePatternDHSS + MessagePatternPSK ) // MaxMsgLen is the maximum number of bytes that can be sent in a single Noise @@ -194,6 +199,7 @@ type HandshakeState struct { e DHKey // local ephemeral keypair rs []byte // remote party's static public key re []byte // remote party's ephemeral public key + psk []byte // preshared key, maybe zero length messagePatterns [][]MessagePattern shouldWrite bool msgIdx int @@ -221,9 +227,13 @@ type Config struct { // be identical on both sides for the handshake to succeed. Prologue []byte - // PresharedKey is the optional pre-shared key for the handshake. + // PresharedKey is the optional preshared key for the handshake. PresharedKey []byte + // PresharedKeyPlacement specifies the placement position of the PSK token + // when PresharedKey is specified + PresharedKeyPlacement int + // StaticKeypair is this peer's static keypair, required if part of the // handshake. StaticKeypair DHKey @@ -247,6 +257,7 @@ func NewHandshakeState(c Config) *HandshakeState { s: c.StaticKeypair, e: c.EphemeralKeypair, rs: c.PeerStatic, + psk: c.PresharedKey, messagePatterns: c.Pattern.Messages, shouldWrite: c.Initiator, rng: c.Random, @@ -259,15 +270,21 @@ func NewHandshakeState(c Config) *HandshakeState { copy(hs.re, c.PeerEphemeral) } hs.ss.cs = c.CipherSuite - namePrefix := "Noise_" - if len(c.PresharedKey) > 0 { - namePrefix = "NoisePSK_" + pskModifier := "" + if len(hs.psk) > 0 { + if len(hs.psk) != 32 { + panic("noise: specification mandates 256-bit preshared keys") + } + pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement) + hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...) + if (c.PresharedKeyPlacement == 0) { + hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...) + } else { + hs.messagePatterns[c.PresharedKeyPlacement - 1] = append(hs.messagePatterns[c.PresharedKeyPlacement - 1], MessagePatternPSK) + } } - hs.ss.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.ss.cs.Name()))) + hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name()))) hs.ss.MixHash(c.Prologue) - if len(c.PresharedKey) > 0 { - hs.ss.MixPresharedKey(c.PresharedKey) - } for _, m := range c.Pattern.InitiatorPreMessages { switch { case c.Initiator && m == MessagePatternS: @@ -318,7 +335,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState s.e = s.ss.cs.GenerateKeypair(s.rng) out = append(out, s.e.Public...) s.ss.MixHash(s.e.Public) - if s.ss.hasPSK { + if len(s.psk) > 0 { s.ss.MixKey(s.e.Public) } case MessagePatternS: @@ -334,6 +351,8 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) case MessagePatternDHSS: s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) + case MessagePatternPSK: + s.ss.MixKeyAndHash(s.psk) } } s.shouldWrite = false @@ -385,7 +404,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, s.re = s.re[:s.ss.cs.DHLen()] copy(s.re, message) s.ss.MixHash(s.re) - if s.ss.hasPSK { + if len(s.psk) > 0 { s.ss.MixKey(s.re) } case MessagePatternS: @@ -407,6 +426,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) case MessagePatternDHSS: s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) + case MessagePatternPSK: + s.ss.MixKeyAndHash(s.psk) } } out, err = s.ss.DecryptAndHash(out, message)