commit 14af5e1fc86f8293637b4d486b7d94bb2291ad4a Author: Jonathan Rudenberg Date: Sun Nov 15 12:49:59 2015 -0500 Initial implementation diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c1398eb --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +Flynn® is a trademark of Prime Directive, Inc. + +Copyright (c) 2015 Prime Directive, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Prime Directive, Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/cipher_suite.go b/cipher_suite.go new file mode 100644 index 0000000..8a7f24d --- /dev/null +++ b/cipher_suite.go @@ -0,0 +1,166 @@ +package noise + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "encoding/binary" + "hash" + "io" + + "github.com/devi/blake2/blake2b" + "github.com/devi/blake2/blake2s" + "github.com/devi/chap" + "golang.org/x/crypto/curve25519" +) + +type DHKey struct { + Private []byte + Public []byte +} + +type DHFunc interface { + GenerateKeypair(random io.Reader) DHKey + DH(privkey, pubkey []byte) []byte + DHLen() int + DHName() string +} + +type HashFunc interface { + Hash() hash.Hash + HashName() string +} + +type CipherFunc interface { + Cipher(k [32]byte) Cipher + CipherName() string +} + +type Cipher interface { + Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte + Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) +} + +type CipherSuite interface { + DHFunc + CipherFunc + HashFunc + Name() []byte +} + +func NewCipherSuite(dh DHFunc, c CipherFunc, h HashFunc) CipherSuite { + return ciphersuite{ + DHFunc: dh, + CipherFunc: c, + HashFunc: h, + name: []byte(dh.DHName() + "_" + c.CipherName() + "_" + h.HashName()), + } +} + +type ciphersuite struct { + DHFunc + CipherFunc + HashFunc + name []byte +} + +func (s ciphersuite) Name() []byte { return s.name } + +var DH25519 DHFunc = dh25519{} + +type dh25519 struct{} + +func (dh25519) GenerateKeypair(rng io.Reader) DHKey { + var pubkey, privkey [32]byte + if rng == nil { + rng = rand.Reader + } + if _, err := io.ReadFull(rng, privkey[:]); err != nil { + panic(err) + } + curve25519.ScalarBaseMult(&pubkey, &privkey) + return DHKey{Private: privkey[:], Public: pubkey[:]} +} + +func (dh25519) DH(privkey, pubkey []byte) []byte { + var dst, in, base [32]byte + copy(in[:], privkey) + copy(base[:], pubkey) + curve25519.ScalarMult(&dst, &in, &base) + return dst[:] +} + +func (dh25519) DHLen() int { return 32 } +func (dh25519) DHName() string { return "25519" } + +type cipherFn struct { + fn func([32]byte) Cipher + name string +} + +func (c cipherFn) Cipher(k [32]byte) Cipher { return c.fn(k) } +func (c cipherFn) CipherName() string { return c.name } + +var CipherAESGCM CipherFunc = cipherFn{ + func(k [32]byte) Cipher { + c, err := aes.NewCipher(k[:]) + if err != nil { + panic(err) + } + gcm, err := cipher.NewGCM(c) + if err != nil { + panic(err) + } + return aeadCipher{ + gcm, + func(n uint64) []byte { + var nonce [12]byte + binary.BigEndian.PutUint64(nonce[4:], n) + return nonce[:] + }, + } + }, + "AESGCM", +} + +var CipherChaChaPoly CipherFunc = cipherFn{ + func(k [32]byte) Cipher { + return aeadCipher{ + chap.NewCipher(&k), + func(n uint64) []byte { + var nonce [12]byte + binary.LittleEndian.PutUint64(nonce[4:], n) + return nonce[:] + }, + } + }, + "ChaChaPoly", +} + +type aeadCipher struct { + cipher.AEAD + nonce func(uint64) []byte +} + +func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte { + return c.Seal(out, c.nonce(n), plaintext, ad) +} + +func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) { + return c.Open(out, c.nonce(n), ciphertext, ad) +} + +type hashFn struct { + fn func() hash.Hash + name string +} + +func (h hashFn) Hash() hash.Hash { return h.fn() } +func (h hashFn) HashName() string { return h.name } + +var HashSHA256 HashFunc = hashFn{sha256.New, "SHA256"} +var HashSHA512 HashFunc = hashFn{sha512.New, "SHA512"} +var HashBLAKE2b HashFunc = hashFn{blake2b.New, "BLAKE2b"} +var HashBLAKE2s HashFunc = hashFn{blake2s.New, "BLAKE2s"} diff --git a/hkdf.go b/hkdf.go new file mode 100644 index 0000000..f318bcf --- /dev/null +++ b/hkdf.go @@ -0,0 +1,30 @@ +package noise + +import ( + "crypto/hmac" + "hash" +) + +func HKDF(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) ([]byte, []byte) { + if len(out1) > 0 { + panic("len(out1) > 0") + } + if len(out2) > 0 { + panic("len(out2) > 0") + } + + tempMAC := hmac.New(h, chainingKey) + tempMAC.Write(inputKeyMaterial) + tempKey := tempMAC.Sum(out2) + + out1MAC := hmac.New(h, tempKey) + out1MAC.Write([]byte{0x01}) + out1 = out1MAC.Sum(out1) + + out2MAC := hmac.New(h, tempKey) + out2MAC.Write(out1) + out2MAC.Write([]byte{0x02}) + out2 = out2MAC.Sum(tempKey[:0]) + + return out1, out2 +} diff --git a/noise_test.go b/noise_test.go new file mode 100644 index 0000000..9634cfa --- /dev/null +++ b/noise_test.go @@ -0,0 +1,222 @@ +package noise + +import ( + "encoding/hex" + "testing" + + . "gopkg.in/check.v1" +) + +func Test(t *testing.T) { TestingT(t) } + +type NoiseSuite struct{} + +var _ = Suite(&NoiseSuite{}) + +type RandomInc byte + +func (r *RandomInc) Read(p []byte) (int, error) { + for i := range p { + p[i] = byte(*r) + *r = (*r) + 1 + } + return len(p), nil +} + +func (NoiseSuite) TestN(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) + rng := new(RandomInc) + staticR := cs.GenerateKeypair(rng) + hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, staticR.Public, nil) + + hello, _, _ := hs.WriteMessage(nil, nil) + expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4") + c.Assert(hello, DeepEquals, expected) +} + +func (NoiseSuite) TestX(c *C) { + cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashSHA256) + rng := new(RandomInc) + staticI := cs.GenerateKeypair(rng) + staticR := cs.GenerateKeypair(rng) + hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, &staticI, nil, staticR.Public, nil) + + hello, _, _ := hs.WriteMessage(nil, nil) + expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d") + c.Assert(hello, DeepEquals, expected) +} + +func (NoiseSuite) TestNN(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 35) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 52) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "defg") + + expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0") + c.Assert(msg, DeepEquals, expected) +} + +func (NoiseSuite) TestXX(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + staticI := cs.GenerateKeypair(rngI) + staticR := cs.GenerateKeypair(rngR) + + hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 35) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 100) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "defg") + + msg, _, _ = hsI.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 64) + res, _, _, err = hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + expected, _ := hex.DecodeString("8127f4b35cdbdf0935fcf1ec99016d1dcbc350055b8af360be196905dfb50a2c1c38a7ca9cb0cfe8f4576f36c47a4933eee32288f590ac4305d4b53187577be7") + c.Assert(msg, DeepEquals, expected) +} + +func (NoiseSuite) TestIK(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + staticI := cs.GenerateKeypair(rngI) + staticR := cs.GenerateKeypair(rngR) + + hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), &staticI, nil, staticR.Public, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), &staticR, nil, nil, nil) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 99) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 68) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "defg") + + expected, _ := hex.DecodeString("5a491c3d8524aee516e7edccba51433ebe651002f0f79fd79dc6a4bf65ecd7b13543f1cc7910a367ffc3686f9c03e62e7555a9411133bb3194f27a9433507b30d858d578") + c.Assert(msg, DeepEquals, expected) +} + +func (NoiseSuite) TestXE(c *C) { + cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + staticI := cs.GenerateKeypair(rngI) + staticR := cs.GenerateKeypair(rngR) + ephR := cs.GenerateKeypair(rngR) + + hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, &staticI, nil, staticR.Public, ephR.Public) + hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, &staticR, &ephR, nil, nil) + + msg, _, _ := hsI.WriteMessage(nil, []byte("abc")) + c.Assert(msg, HasLen, 51) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abc") + + msg, _, _ = hsR.WriteMessage(nil, []byte("defg")) + c.Assert(msg, HasLen, 68) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "defg") + + msg, _, _ = hsI.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 64) + res, _, _, err = hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + expected, _ := hex.DecodeString("08439f380b6f128a1465840d558f06abb1141cf5708a9dcf573d6e4fae01f90f7c9b8ef856bdc483df643a9d240ab6d38d9af9f3812ef44a465e32f8227a7c8b") + c.Assert(msg, DeepEquals, expected) +} + +func (NoiseSuite) TestXXRoundtrip(c *C) { + cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256) + rngI := new(RandomInc) + rngR := new(RandomInc) + *rngR = 1 + + staticI := cs.GenerateKeypair(rngI) + staticR := cs.GenerateKeypair(rngR) + + hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil) + hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil) + + // -> e + msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef")) + c.Assert(msg, HasLen, 38) + res, _, _, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "abcdef") + + // <- e, dhee, s, dhse + msg, _, _ = hsR.WriteMessage(nil, nil) + c.Assert(msg, HasLen, 96) + res, _, _, err = hsI.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(res, HasLen, 0) + + // -> s, dhse + payload := "0123456789012345678901234567890123456789012345678901234567890123456789" + msg, csI0, csI1 := hsI.WriteMessage(nil, []byte(payload)) + c.Assert(msg, HasLen, 134) + res, csR0, csR1, err := hsR.ReadMessage(nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, payload) + + // transport message I -> R + msg = csI0.Encrypt(nil, nil, []byte("wubba")) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "wubba") + + // transport message I -> R again + msg = csI0.Encrypt(nil, nil, []byte("aleph")) + res, err = csR0.Decrypt(nil, nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "aleph") + + // transport message R <- I + msg = csR1.Encrypt(nil, nil, []byte("worri")) + res, err = csI1.Decrypt(nil, nil, msg) + c.Assert(err, IsNil) + c.Assert(string(res), Equals, "worri") +} diff --git a/patterns.go b/patterns.go new file mode 100644 index 0000000..b9d8144 --- /dev/null +++ b/patterns.go @@ -0,0 +1,170 @@ +package noise + +var HandshakeNN = HandshakePattern{ + Name: "NN", + Messages: [][]MessagePattern{ + {MessagePatternE}, + {MessagePatternE, MessagePatternDHEE}, + }, +} + +var HandshakeKN = HandshakePattern{ + Name: "KN", + InitiatorPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + }, +} + +var HandshakeNK = HandshakePattern{ + Name: "NK", + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES}, + {MessagePatternE, MessagePatternDHEE}, + }, +} + +var HandshakeKK = HandshakePattern{ + Name: "KK", + InitiatorPreMessages: []MessagePattern{MessagePatternS}, + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES, MessagePatternDHSS}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + }, +} + +var HandshakeNE = HandshakePattern{ + Name: "NE", + ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHEE, MessagePatternDHSE}, + {MessagePatternE, MessagePatternDHEE}, + }, +} + +var HandshakeKE = HandshakePattern{ + Name: "KE", + InitiatorPreMessages: []MessagePattern{MessagePatternS}, + ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternDHSE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + }, +} + +var HandshakeNX = HandshakePattern{ + Name: "NX", + Messages: [][]MessagePattern{ + {MessagePatternE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeKX = HandshakePattern{ + Name: "KX", + InitiatorPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeXN = HandshakePattern{ + Name: "XN", + Messages: [][]MessagePattern{ + {MessagePatternE}, + {MessagePatternE, MessagePatternDHEE}, + {MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeIN = HandshakePattern{ + Name: "IN", + Messages: [][]MessagePattern{ + {MessagePatternS, MessagePatternE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + }, +} + +var HandshakeXK = HandshakePattern{ + Name: "XK", + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES}, + {MessagePatternE, MessagePatternDHEE}, + {MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeIK = HandshakePattern{ + Name: "IK", + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES, MessagePatternS, MessagePatternDHSS}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + }, +} + +var HandshakeXE = HandshakePattern{ + Name: "XE", + ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + {MessagePatternE, MessagePatternDHEE}, + {MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeIE = HandshakePattern{ + Name: "IE", + ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES}, + }, +} + +var HandshakeXX = HandshakePattern{ + Name: "XX", + Messages: [][]MessagePattern{ + {MessagePatternE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternS, MessagePatternDHSE}, + {MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeIX = HandshakePattern{ + Name: "IX", + Messages: [][]MessagePattern{ + {MessagePatternS, MessagePatternE}, + {MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE}, + }, +} + +var HandshakeN = HandshakePattern{ + Name: "N", + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES}, + }, +} + +var HandshakeK = HandshakePattern{ + Name: "K", + InitiatorPreMessages: []MessagePattern{MessagePatternS}, + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES, MessagePatternDHSS}, + }, +} + +var HandshakeX = HandshakePattern{ + Name: "X", + ResponderPreMessages: []MessagePattern{MessagePatternS}, + Messages: [][]MessagePattern{ + {MessagePatternE, MessagePatternDHES, MessagePatternS, MessagePatternDHSS}, + }, +} diff --git a/state.go b/state.go new file mode 100644 index 0000000..b2d8afe --- /dev/null +++ b/state.go @@ -0,0 +1,266 @@ +package noise + +import ( + "errors" + "io" +) + +type CipherState struct { + cs CipherSuite + c Cipher + k [32]byte + n uint64 +} + +func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte { + out = s.c.Encrypt(out, s.n, ad, plaintext) + s.n++ + return out +} + +func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) { + out, err := s.c.Decrypt(out, s.n, ad, ciphertext) + s.n++ + return out, err +} + +type SymmetricState struct { + CipherState + hasKey bool + ck []byte + h []byte +} + +func (s *SymmetricState) InitializeSymmetric(handshakeName []byte) { + h := s.cs.Hash() + if len(handshakeName) <= h.Size() { + s.h = make([]byte, h.Size()) + copy(s.h, handshakeName) + } else { + h.Write(handshakeName) + s.h = h.Sum(nil) + } + s.ck = make([]byte, len(s.h)) + copy(s.ck, s.h) +} + +func (s *SymmetricState) MixKey(dhOutput []byte) { + s.n = 0 + s.hasKey = true + var hk []byte + s.ck, hk = HKDF(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput) + copy(s.k[:], hk) + s.c = s.cs.Cipher(s.k) +} + +func (s *SymmetricState) MixHash(data []byte) { + h := s.cs.Hash() + h.Write(s.h) + h.Write(data) + s.h = h.Sum(s.h[:0]) +} + +func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte { + if !s.hasKey { + s.MixHash(plaintext) + return append(out, plaintext...) + } + ciphertext := s.Encrypt(out, s.h, plaintext) + s.MixHash(ciphertext[len(out):]) + return ciphertext +} + +func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { + if !s.hasKey { + s.MixHash(data) + return append(out, data...), nil + } + plaintext, err := s.Decrypt(out, s.h, data) + if err != nil { + return nil, err + } + s.MixHash(data) + return plaintext, nil +} + +func (s *SymmetricState) Split() (*CipherState, *CipherState) { + s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs} + hk1, hk2 := HKDF(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil) + copy(s1.k[:], hk1) + copy(s2.k[:], hk2) + s1.c = s.cs.Cipher(s1.k) + s2.c = s.cs.Cipher(s2.k) + return s1, s2 +} + +type MessagePattern int + +type HandshakePattern struct { + Name string + InitiatorPreMessages []MessagePattern + ResponderPreMessages []MessagePattern + Messages [][]MessagePattern +} + +const ( + MessagePatternS MessagePattern = iota + MessagePatternE + MessagePatternDHEE + MessagePatternDHES + MessagePatternDHSE + MessagePatternDHSS +) + +const MaxMsgLen = 65535 + +type HandshakeState struct { + SymmetricState + s DHKey // local static keypair + e DHKey // local ephemeral keypair + rs []byte // remote party's static public key + re []byte // remote party's ephemeral public key + messagePatterns [][]MessagePattern + shouldWrite bool + msgIdx int + rng io.Reader +} + +func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState { + hs := &HandshakeState{ + rs: newRS, + re: newRE, + messagePatterns: newHandshakePattern.Messages, + shouldWrite: initiator, + rng: rng, + } + hs.SymmetricState.cs = cs + if newE != nil { + hs.e = *newE + } + if newS != nil { + hs.s = *newS + } + hs.InitializeSymmetric([]byte("Noise_" + newHandshakePattern.Name + "_" + string(cs.Name()))) + hs.MixHash(prologue) + for _, m := range newHandshakePattern.InitiatorPreMessages { + switch { + case initiator && m == MessagePatternS: + hs.MixHash(newS.Public) + case initiator && m == MessagePatternE: + hs.MixHash(newE.Public) + case !initiator && m == MessagePatternS: + hs.MixHash(newRS) + case !initiator && m == MessagePatternE: + hs.MixHash(newRE) + } + } + for _, m := range newHandshakePattern.ResponderPreMessages { + switch { + case !initiator && m == MessagePatternS: + hs.MixHash(newS.Public) + case !initiator && m == MessagePatternE: + hs.MixHash(newE.Public) + case initiator && m == MessagePatternS: + hs.MixHash(newRS) + case initiator && m == MessagePatternE: + hs.MixHash(newRE) + } + } + return hs +} + +func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState) { + if !s.shouldWrite { + panic("noise: unexpected call to WriteMessage should be ReadMessage") + } + if s.msgIdx > len(s.messagePatterns)-1 { + panic("noise: no handshake messages left") + } + if len(payload) > MaxMsgLen { + panic("noise: message is too long") + } + + for _, msg := range s.messagePatterns[s.msgIdx] { + switch msg { + case MessagePatternE: + s.e = s.cs.GenerateKeypair(s.rng) + out = s.EncryptAndHash(out, s.e.Public) + case MessagePatternS: + out = s.EncryptAndHash(out, s.s.Public) + case MessagePatternDHEE: + s.MixKey(s.cs.DH(s.e.Private, s.re)) + case MessagePatternDHES: + s.MixKey(s.cs.DH(s.e.Private, s.rs)) + case MessagePatternDHSE: + s.MixKey(s.cs.DH(s.s.Private, s.re)) + case MessagePatternDHSS: + s.MixKey(s.cs.DH(s.s.Private, s.rs)) + } + } + s.shouldWrite = false + s.msgIdx++ + out = s.EncryptAndHash(out, payload) + + if s.msgIdx >= len(s.messagePatterns) { + cs1, cs2 := s.Split() + return out, cs1, cs2 + } + + return out, nil, nil +} + +var ErrShortMessage = errors.New("noise: message is too short") + +func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) { + if s.shouldWrite { + panic("noise: unexpected call to ReadMessage should be WriteMessage") + } + if s.msgIdx > len(s.messagePatterns)-1 { + panic("noise: no handshake messages left") + } + + var err error + for _, msg := range s.messagePatterns[s.msgIdx] { + switch msg { + case MessagePatternE, MessagePatternS: + expected := s.cs.DHLen() + if s.hasKey { + expected += 16 + } + if len(message) < expected { + return nil, nil, nil, ErrShortMessage + } + switch msg { + case MessagePatternE: + s.re, err = s.DecryptAndHash(s.re[:0], message[:expected]) + case MessagePatternS: + s.rs, err = s.DecryptAndHash(s.rs[:0], message[:expected]) + } + if err != nil { + return nil, nil, nil, err + } + message = message[expected:] + case MessagePatternDHEE: + s.MixKey(s.cs.DH(s.e.Private, s.re)) + case MessagePatternDHES: + s.MixKey(s.cs.DH(s.s.Private, s.re)) + case MessagePatternDHSE: + s.MixKey(s.cs.DH(s.e.Private, s.rs)) + case MessagePatternDHSS: + s.MixKey(s.cs.DH(s.s.Private, s.rs)) + } + } + s.shouldWrite = true + s.msgIdx++ + out, err = s.DecryptAndHash(out, message) + if err != nil { + return nil, nil, nil, err + } + + if s.msgIdx >= len(s.messagePatterns) { + cs1, cs2 := s.Split() + return out, cs1, cs2, nil + } + + return out, nil, nil, nil +}