Initial implementation

This commit is contained in:
Jonathan Rudenberg 2015-11-15 12:49:59 -05:00
commit 14af5e1fc8
6 changed files with 883 additions and 0 deletions

29
LICENSE Normal file
View File

@ -0,0 +1,29 @@
Flynn® is a trademark of Prime Directive, Inc.
Copyright (c) 2015 Prime Directive, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Prime Directive, Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

166
cipher_suite.go Normal file
View File

@ -0,0 +1,166 @@
package noise
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"encoding/binary"
"hash"
"io"
"github.com/devi/blake2/blake2b"
"github.com/devi/blake2/blake2s"
"github.com/devi/chap"
"golang.org/x/crypto/curve25519"
)
type DHKey struct {
Private []byte
Public []byte
}
type DHFunc interface {
GenerateKeypair(random io.Reader) DHKey
DH(privkey, pubkey []byte) []byte
DHLen() int
DHName() string
}
type HashFunc interface {
Hash() hash.Hash
HashName() string
}
type CipherFunc interface {
Cipher(k [32]byte) Cipher
CipherName() string
}
type Cipher interface {
Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte
Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error)
}
type CipherSuite interface {
DHFunc
CipherFunc
HashFunc
Name() []byte
}
func NewCipherSuite(dh DHFunc, c CipherFunc, h HashFunc) CipherSuite {
return ciphersuite{
DHFunc: dh,
CipherFunc: c,
HashFunc: h,
name: []byte(dh.DHName() + "_" + c.CipherName() + "_" + h.HashName()),
}
}
type ciphersuite struct {
DHFunc
CipherFunc
HashFunc
name []byte
}
func (s ciphersuite) Name() []byte { return s.name }
var DH25519 DHFunc = dh25519{}
type dh25519 struct{}
func (dh25519) GenerateKeypair(rng io.Reader) DHKey {
var pubkey, privkey [32]byte
if rng == nil {
rng = rand.Reader
}
if _, err := io.ReadFull(rng, privkey[:]); err != nil {
panic(err)
}
curve25519.ScalarBaseMult(&pubkey, &privkey)
return DHKey{Private: privkey[:], Public: pubkey[:]}
}
func (dh25519) DH(privkey, pubkey []byte) []byte {
var dst, in, base [32]byte
copy(in[:], privkey)
copy(base[:], pubkey)
curve25519.ScalarMult(&dst, &in, &base)
return dst[:]
}
func (dh25519) DHLen() int { return 32 }
func (dh25519) DHName() string { return "25519" }
type cipherFn struct {
fn func([32]byte) Cipher
name string
}
func (c cipherFn) Cipher(k [32]byte) Cipher { return c.fn(k) }
func (c cipherFn) CipherName() string { return c.name }
var CipherAESGCM CipherFunc = cipherFn{
func(k [32]byte) Cipher {
c, err := aes.NewCipher(k[:])
if err != nil {
panic(err)
}
gcm, err := cipher.NewGCM(c)
if err != nil {
panic(err)
}
return aeadCipher{
gcm,
func(n uint64) []byte {
var nonce [12]byte
binary.BigEndian.PutUint64(nonce[4:], n)
return nonce[:]
},
}
},
"AESGCM",
}
var CipherChaChaPoly CipherFunc = cipherFn{
func(k [32]byte) Cipher {
return aeadCipher{
chap.NewCipher(&k),
func(n uint64) []byte {
var nonce [12]byte
binary.LittleEndian.PutUint64(nonce[4:], n)
return nonce[:]
},
}
},
"ChaChaPoly",
}
type aeadCipher struct {
cipher.AEAD
nonce func(uint64) []byte
}
func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte {
return c.Seal(out, c.nonce(n), plaintext, ad)
}
func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) {
return c.Open(out, c.nonce(n), ciphertext, ad)
}
type hashFn struct {
fn func() hash.Hash
name string
}
func (h hashFn) Hash() hash.Hash { return h.fn() }
func (h hashFn) HashName() string { return h.name }
var HashSHA256 HashFunc = hashFn{sha256.New, "SHA256"}
var HashSHA512 HashFunc = hashFn{sha512.New, "SHA512"}
var HashBLAKE2b HashFunc = hashFn{blake2b.New, "BLAKE2b"}
var HashBLAKE2s HashFunc = hashFn{blake2s.New, "BLAKE2s"}

30
hkdf.go Normal file
View File

@ -0,0 +1,30 @@
package noise
import (
"crypto/hmac"
"hash"
)
func HKDF(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) ([]byte, []byte) {
if len(out1) > 0 {
panic("len(out1) > 0")
}
if len(out2) > 0 {
panic("len(out2) > 0")
}
tempMAC := hmac.New(h, chainingKey)
tempMAC.Write(inputKeyMaterial)
tempKey := tempMAC.Sum(out2)
out1MAC := hmac.New(h, tempKey)
out1MAC.Write([]byte{0x01})
out1 = out1MAC.Sum(out1)
out2MAC := hmac.New(h, tempKey)
out2MAC.Write(out1)
out2MAC.Write([]byte{0x02})
out2 = out2MAC.Sum(tempKey[:0])
return out1, out2
}

222
noise_test.go Normal file
View File

@ -0,0 +1,222 @@
package noise
import (
"encoding/hex"
"testing"
. "gopkg.in/check.v1"
)
func Test(t *testing.T) { TestingT(t) }
type NoiseSuite struct{}
var _ = Suite(&NoiseSuite{})
type RandomInc byte
func (r *RandomInc) Read(p []byte) (int, error) {
for i := range p {
p[i] = byte(*r)
*r = (*r) + 1
}
return len(p), nil
}
func (NoiseSuite) TestN(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rng := new(RandomInc)
staticR := cs.GenerateKeypair(rng)
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, staticR.Public, nil)
hello, _, _ := hs.WriteMessage(nil, nil)
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4")
c.Assert(hello, DeepEquals, expected)
}
func (NoiseSuite) TestX(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashSHA256)
rng := new(RandomInc)
staticI := cs.GenerateKeypair(rng)
staticR := cs.GenerateKeypair(rng)
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, &staticI, nil, staticR.Public, nil)
hello, _, _ := hs.WriteMessage(nil, nil)
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d")
c.Assert(hello, DeepEquals, expected)
}
func (NoiseSuite) TestNN(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 35)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 52)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestXX(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 35)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 100)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
msg, _, _ = hsI.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 64)
res, _, _, err = hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
expected, _ := hex.DecodeString("8127f4b35cdbdf0935fcf1ec99016d1dcbc350055b8af360be196905dfb50a2c1c38a7ca9cb0cfe8f4576f36c47a4933eee32288f590ac4305d4b53187577be7")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestIK(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), &staticI, nil, staticR.Public, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), &staticR, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 99)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 68)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("5a491c3d8524aee516e7edccba51433ebe651002f0f79fd79dc6a4bf65ecd7b13543f1cc7910a367ffc3686f9c03e62e7555a9411133bb3194f27a9433507b30d858d578")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestXE(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
ephR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, &staticI, nil, staticR.Public, ephR.Public)
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, &staticR, &ephR, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 51)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 68)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
msg, _, _ = hsI.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 64)
res, _, _, err = hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
expected, _ := hex.DecodeString("08439f380b6f128a1465840d558f06abb1141cf5708a9dcf573d6e4fae01f90f7c9b8ef856bdc483df643a9d240ab6d38d9af9f3812ef44a465e32f8227a7c8b")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestXXRoundtrip(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
// -> e
msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef"))
c.Assert(msg, HasLen, 38)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abcdef")
// <- e, dhee, s, dhse
msg, _, _ = hsR.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 96)
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
// -> s, dhse
payload := "0123456789012345678901234567890123456789012345678901234567890123456789"
msg, csI0, csI1 := hsI.WriteMessage(nil, []byte(payload))
c.Assert(msg, HasLen, 134)
res, csR0, csR1, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, payload)
// transport message I -> R
msg = csI0.Encrypt(nil, nil, []byte("wubba"))
res, err = csR0.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "wubba")
// transport message I -> R again
msg = csI0.Encrypt(nil, nil, []byte("aleph"))
res, err = csR0.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "aleph")
// transport message R <- I
msg = csR1.Encrypt(nil, nil, []byte("worri"))
res, err = csI1.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "worri")
}

170
patterns.go Normal file
View File

@ -0,0 +1,170 @@
package noise
var HandshakeNN = HandshakePattern{
Name: "NN",
Messages: [][]MessagePattern{
{MessagePatternE},
{MessagePatternE, MessagePatternDHEE},
},
}
var HandshakeKN = HandshakePattern{
Name: "KN",
InitiatorPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
var HandshakeNK = HandshakePattern{
Name: "NK",
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES},
{MessagePatternE, MessagePatternDHEE},
},
}
var HandshakeKK = HandshakePattern{
Name: "KK",
InitiatorPreMessages: []MessagePattern{MessagePatternS},
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES, MessagePatternDHSS},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
var HandshakeNE = HandshakePattern{
Name: "NE",
ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHEE, MessagePatternDHSE},
{MessagePatternE, MessagePatternDHEE},
},
}
var HandshakeKE = HandshakePattern{
Name: "KE",
InitiatorPreMessages: []MessagePattern{MessagePatternS},
ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternDHSE},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
var HandshakeNX = HandshakePattern{
Name: "NX",
Messages: [][]MessagePattern{
{MessagePatternE},
{MessagePatternE, MessagePatternDHEE, MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeKX = HandshakePattern{
Name: "KX",
InitiatorPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeXN = HandshakePattern{
Name: "XN",
Messages: [][]MessagePattern{
{MessagePatternE},
{MessagePatternE, MessagePatternDHEE},
{MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeIN = HandshakePattern{
Name: "IN",
Messages: [][]MessagePattern{
{MessagePatternS, MessagePatternE},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
var HandshakeXK = HandshakePattern{
Name: "XK",
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES},
{MessagePatternE, MessagePatternDHEE},
{MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeIK = HandshakePattern{
Name: "IK",
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES, MessagePatternS, MessagePatternDHSS},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
var HandshakeXE = HandshakePattern{
Name: "XE",
ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
{MessagePatternE, MessagePatternDHEE},
{MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeIE = HandshakePattern{
Name: "IE",
ResponderPreMessages: []MessagePattern{MessagePatternS, MessagePatternE},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
var HandshakeXX = HandshakePattern{
Name: "XX",
Messages: [][]MessagePattern{
{MessagePatternE},
{MessagePatternE, MessagePatternDHEE, MessagePatternS, MessagePatternDHSE},
{MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeIX = HandshakePattern{
Name: "IX",
Messages: [][]MessagePattern{
{MessagePatternS, MessagePatternE},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
},
}
var HandshakeN = HandshakePattern{
Name: "N",
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES},
},
}
var HandshakeK = HandshakePattern{
Name: "K",
InitiatorPreMessages: []MessagePattern{MessagePatternS},
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES, MessagePatternDHSS},
},
}
var HandshakeX = HandshakePattern{
Name: "X",
ResponderPreMessages: []MessagePattern{MessagePatternS},
Messages: [][]MessagePattern{
{MessagePatternE, MessagePatternDHES, MessagePatternS, MessagePatternDHSS},
},
}

266
state.go Normal file
View File

@ -0,0 +1,266 @@
package noise
import (
"errors"
"io"
)
type CipherState struct {
cs CipherSuite
c Cipher
k [32]byte
n uint64
}
func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
out = s.c.Encrypt(out, s.n, ad, plaintext)
s.n++
return out
}
func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
s.n++
return out, err
}
type SymmetricState struct {
CipherState
hasKey bool
ck []byte
h []byte
}
func (s *SymmetricState) InitializeSymmetric(handshakeName []byte) {
h := s.cs.Hash()
if len(handshakeName) <= h.Size() {
s.h = make([]byte, h.Size())
copy(s.h, handshakeName)
} else {
h.Write(handshakeName)
s.h = h.Sum(nil)
}
s.ck = make([]byte, len(s.h))
copy(s.ck, s.h)
}
func (s *SymmetricState) MixKey(dhOutput []byte) {
s.n = 0
s.hasKey = true
var hk []byte
s.ck, hk = HKDF(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput)
copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k)
}
func (s *SymmetricState) MixHash(data []byte) {
h := s.cs.Hash()
h.Write(s.h)
h.Write(data)
s.h = h.Sum(s.h[:0])
}
func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte {
if !s.hasKey {
s.MixHash(plaintext)
return append(out, plaintext...)
}
ciphertext := s.Encrypt(out, s.h, plaintext)
s.MixHash(ciphertext[len(out):])
return ciphertext
}
func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
if !s.hasKey {
s.MixHash(data)
return append(out, data...), nil
}
plaintext, err := s.Decrypt(out, s.h, data)
if err != nil {
return nil, err
}
s.MixHash(data)
return plaintext, nil
}
func (s *SymmetricState) Split() (*CipherState, *CipherState) {
s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
hk1, hk2 := HKDF(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil)
copy(s1.k[:], hk1)
copy(s2.k[:], hk2)
s1.c = s.cs.Cipher(s1.k)
s2.c = s.cs.Cipher(s2.k)
return s1, s2
}
type MessagePattern int
type HandshakePattern struct {
Name string
InitiatorPreMessages []MessagePattern
ResponderPreMessages []MessagePattern
Messages [][]MessagePattern
}
const (
MessagePatternS MessagePattern = iota
MessagePatternE
MessagePatternDHEE
MessagePatternDHES
MessagePatternDHSE
MessagePatternDHSS
)
const MaxMsgLen = 65535
type HandshakeState struct {
SymmetricState
s DHKey // local static keypair
e DHKey // local ephemeral keypair
rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key
messagePatterns [][]MessagePattern
shouldWrite bool
msgIdx int
rng io.Reader
}
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
hs := &HandshakeState{
rs: newRS,
re: newRE,
messagePatterns: newHandshakePattern.Messages,
shouldWrite: initiator,
rng: rng,
}
hs.SymmetricState.cs = cs
if newE != nil {
hs.e = *newE
}
if newS != nil {
hs.s = *newS
}
hs.InitializeSymmetric([]byte("Noise_" + newHandshakePattern.Name + "_" + string(cs.Name())))
hs.MixHash(prologue)
for _, m := range newHandshakePattern.InitiatorPreMessages {
switch {
case initiator && m == MessagePatternS:
hs.MixHash(newS.Public)
case initiator && m == MessagePatternE:
hs.MixHash(newE.Public)
case !initiator && m == MessagePatternS:
hs.MixHash(newRS)
case !initiator && m == MessagePatternE:
hs.MixHash(newRE)
}
}
for _, m := range newHandshakePattern.ResponderPreMessages {
switch {
case !initiator && m == MessagePatternS:
hs.MixHash(newS.Public)
case !initiator && m == MessagePatternE:
hs.MixHash(newE.Public)
case initiator && m == MessagePatternS:
hs.MixHash(newRS)
case initiator && m == MessagePatternE:
hs.MixHash(newRE)
}
}
return hs
}
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState) {
if !s.shouldWrite {
panic("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
panic("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
panic("noise: message is too long")
}
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
s.e = s.cs.GenerateKeypair(s.rng)
out = s.EncryptAndHash(out, s.e.Public)
case MessagePatternS:
out = s.EncryptAndHash(out, s.s.Public)
case MessagePatternDHEE:
s.MixKey(s.cs.DH(s.e.Private, s.re))
case MessagePatternDHES:
s.MixKey(s.cs.DH(s.e.Private, s.rs))
case MessagePatternDHSE:
s.MixKey(s.cs.DH(s.s.Private, s.re))
case MessagePatternDHSS:
s.MixKey(s.cs.DH(s.s.Private, s.rs))
}
}
s.shouldWrite = false
s.msgIdx++
out = s.EncryptAndHash(out, payload)
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.Split()
return out, cs1, cs2
}
return out, nil, nil
}
var ErrShortMessage = errors.New("noise: message is too short")
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
if s.shouldWrite {
panic("noise: unexpected call to ReadMessage should be WriteMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
panic("noise: no handshake messages left")
}
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE, MessagePatternS:
expected := s.cs.DHLen()
if s.hasKey {
expected += 16
}
if len(message) < expected {
return nil, nil, nil, ErrShortMessage
}
switch msg {
case MessagePatternE:
s.re, err = s.DecryptAndHash(s.re[:0], message[:expected])
case MessagePatternS:
s.rs, err = s.DecryptAndHash(s.rs[:0], message[:expected])
}
if err != nil {
return nil, nil, nil, err
}
message = message[expected:]
case MessagePatternDHEE:
s.MixKey(s.cs.DH(s.e.Private, s.re))
case MessagePatternDHES:
s.MixKey(s.cs.DH(s.s.Private, s.re))
case MessagePatternDHSE:
s.MixKey(s.cs.DH(s.e.Private, s.rs))
case MessagePatternDHSS:
s.MixKey(s.cs.DH(s.s.Private, s.rs))
}
}
s.shouldWrite = true
s.msgIdx++
out, err = s.DecryptAndHash(out, message)
if err != nil {
return nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
}