2023-01-10 09:02:37 -04:00

601 lines
17 KiB
Go

// Package noise implements the Noise Protocol Framework.
//
// Noise is a low-level framework for building crypto protocols. Noise protocols
// support mutual and optional authentication, identity hiding, forward secrecy,
// zero round-trip encryption, and other advanced features. For more details,
// visit https://noiseprotocol.org.
package noise
import (
"crypto/rand"
"errors"
"fmt"
"io"
"math"
)
// A CipherState provides symmetric encryption and decryption after a successful
// handshake.
type CipherState struct {
cs CipherSuite
c Cipher
k [32]byte
n uint64
invalid bool
}
// MaxNonce is the maximum value of n that is allowed. ErrMaxNonce is returned
// by Encrypt and Decrypt after this has been reached. 2^64-1 is reserved for rekeys.
const MaxNonce = uint64(math.MaxUint64) - 1
var ErrMaxNonce = errors.New("noise: cipherstate has reached maximum n, a new handshake must be performed")
var ErrCipherSuiteCopied = errors.New("noise: CipherSuite has been copied, state is invalid")
// Encrypt encrypts the plaintext and then appends the ciphertext and an
// authentication tag across the ciphertext and optional authenticated data to
// out. This method automatically increments the nonce after every call, so
// messages must be decrypted in the same order. ErrMaxNonce is returned after
// the maximum nonce of 2^64-2 is reached.
func (s *CipherState) Encrypt(out, ad, plaintext []byte) ([]byte, error) {
if s.invalid {
return nil, ErrCipherSuiteCopied
}
if s.n > MaxNonce {
return nil, ErrMaxNonce
}
out = s.c.Encrypt(out, s.n, ad, plaintext)
s.n++
return out, nil
}
// Decrypt checks the authenticity of the ciphertext and authenticated data and
// then decrypts and appends the plaintext to out. This method automatically
// increments the nonce after every call, messages must be provided in the same
// order that they were encrypted with no missing messages. ErrMaxNonce is
// returned after the maximum nonce of 2^64-2 is reached.
func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
if s.invalid {
return nil, ErrCipherSuiteCopied
}
if s.n > MaxNonce {
return nil, ErrMaxNonce
}
out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
if err != nil {
return nil, err
}
s.n++
return out, nil
}
// Cipher returns the low-level symmetric encryption primitive. It should only
// be used if nonces need to be managed manually, for example with a network
// protocol that can deliver out-of-order messages. This is dangerous, users
// must ensure that they are incrementing a nonce after every encrypt operation.
// After calling this method, it is an error to call Encrypt/Decrypt on the
// CipherState.
func (s *CipherState) Cipher() Cipher {
s.invalid = true
return s.c
}
// Nonce returns the current value of n. This can be used to determine if a
// new handshake should be performed due to approaching MaxNonce.
func (s *CipherState) Nonce() uint64 {
return s.n
}
func (s *CipherState) Rekey() {
var zeros [32]byte
var out []byte
out = s.c.Encrypt(out, math.MaxUint64, []byte{}, zeros[:])
copy(s.k[:], out[:32])
s.c = s.cs.Cipher(s.k)
}
type symmetricState struct {
CipherState
hasK bool
ck []byte
h []byte
prevCK []byte
prevH []byte
}
func (s *symmetricState) InitializeSymmetric(handshakeName []byte) {
h := s.cs.Hash()
if len(handshakeName) <= h.Size() {
s.h = make([]byte, h.Size())
copy(s.h, handshakeName)
} else {
h.Write(handshakeName)
s.h = h.Sum(nil)
}
s.ck = make([]byte, len(s.h))
copy(s.ck, s.h)
}
func (s *symmetricState) MixKey(dhOutput []byte) {
s.n = 0
s.hasK = true
var hk []byte
s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k)
}
func (s *symmetricState) MixHash(data []byte) {
h := s.cs.Hash()
h.Write(s.h)
h.Write(data)
s.h = h.Sum(s.h[:0])
}
func (s *symmetricState) MixKeyAndHash(data []byte) {
var hk []byte
var temp []byte
s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
s.MixHash(temp)
copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k)
s.n = 0
s.hasK = true
}
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
if !s.hasK {
s.MixHash(plaintext)
return append(out, plaintext...), nil
}
ciphertext, err := s.Encrypt(out, s.h, plaintext)
if err != nil {
return nil, err
}
s.MixHash(ciphertext[len(out):])
return ciphertext, nil
}
func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
if !s.hasK {
s.MixHash(data)
return append(out, data...), nil
}
plaintext, err := s.Decrypt(out, s.h, data)
if err != nil {
return nil, err
}
s.MixHash(data)
return plaintext, nil
}
func (s *symmetricState) Split() (*CipherState, *CipherState) {
s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
copy(s1.k[:], hk1)
copy(s2.k[:], hk2)
s1.c = s.cs.Cipher(s1.k)
s2.c = s.cs.Cipher(s2.k)
return s1, s2
}
func (s *symmetricState) Checkpoint() {
if len(s.ck) > cap(s.prevCK) {
s.prevCK = make([]byte, len(s.ck))
}
s.prevCK = s.prevCK[:len(s.ck)]
copy(s.prevCK, s.ck)
if len(s.h) > cap(s.prevH) {
s.prevH = make([]byte, len(s.h))
}
s.prevH = s.prevH[:len(s.h)]
copy(s.prevH, s.h)
}
func (s *symmetricState) Rollback() {
s.ck = s.ck[:len(s.prevCK)]
copy(s.ck, s.prevCK)
s.h = s.h[:len(s.prevH)]
copy(s.h, s.prevH)
}
// A MessagePattern is a single message or operation used in a Noise handshake.
type MessagePattern int
// A HandshakePattern is a list of messages and operations that are used to
// perform a specific Noise handshake.
type HandshakePattern struct {
Name string
InitiatorPreMessages []MessagePattern
ResponderPreMessages []MessagePattern
Messages [][]MessagePattern
}
const (
MessagePatternS MessagePattern = iota
MessagePatternE
MessagePatternDHEE
MessagePatternDHES
MessagePatternDHSE
MessagePatternDHSS
MessagePatternPSK
)
// MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
// message.
const MaxMsgLen = 65535
// A HandshakeState tracks the state of a Noise handshake. It may be discarded
// after the handshake is complete.
type HandshakeState struct {
ss symmetricState
s DHKey // local static keypair
e DHKey // local ephemeral keypair
rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key
psk []byte // preshared key, maybe zero length
messagePatterns [][]MessagePattern
shouldWrite bool
initiator bool
msgIdx int
rng io.Reader
}
// A Config provides the details necessary to process a Noise handshake. It is
// never modified by this package, and can be reused.
type Config struct {
// CipherSuite is the set of cryptographic primitives that will be used.
CipherSuite CipherSuite
// Random is the source for cryptographically appropriate random bytes. If
// zero, it is automatically configured.
Random io.Reader
// Pattern is the pattern for the handshake.
Pattern HandshakePattern
// Initiator must be true if the first message in the handshake will be sent
// by this peer.
Initiator bool
// Prologue is an optional message that has already be communicated and must
// be identical on both sides for the handshake to succeed.
Prologue []byte
// PresharedKey is the optional preshared key for the handshake.
PresharedKey []byte
// PresharedKeyPlacement specifies the placement position of the PSK token
// when PresharedKey is specified
PresharedKeyPlacement int
// StaticKeypair is this peer's static keypair, required if part of the
// handshake.
StaticKeypair DHKey
// EphemeralKeypair is this peer's ephemeral keypair that was provided as
// a pre-message in the handshake.
EphemeralKeypair DHKey
// PeerStatic is the static public key of the remote peer that was provided
// as a pre-message in the handshake.
PeerStatic []byte
// PeerEphemeral is the ephemeral public key of the remote peer that was
// provided as a pre-message in the handshake.
PeerEphemeral []byte
}
// NewHandshakeState starts a new handshake using the provided configuration.
func NewHandshakeState(c Config) (*HandshakeState, error) {
hs := &HandshakeState{
s: c.StaticKeypair,
e: c.EphemeralKeypair,
rs: c.PeerStatic,
psk: c.PresharedKey,
messagePatterns: c.Pattern.Messages,
shouldWrite: c.Initiator,
initiator: c.Initiator,
rng: c.Random,
}
if hs.rng == nil {
hs.rng = rand.Reader
}
if len(c.PeerEphemeral) > 0 {
hs.re = make([]byte, len(c.PeerEphemeral))
copy(hs.re, c.PeerEphemeral)
}
hs.ss.cs = c.CipherSuite
pskModifier := ""
if len(hs.psk) > 0 {
if len(hs.psk) != 32 {
return nil, errors.New("noise: specification mandates 256-bit preshared keys")
}
pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
if c.PresharedKeyPlacement == 0 {
hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
} else {
hs.messagePatterns[c.PresharedKeyPlacement-1] = append(hs.messagePatterns[c.PresharedKeyPlacement-1], MessagePatternPSK)
}
}
hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
hs.ss.MixHash(c.Prologue)
for _, m := range c.Pattern.InitiatorPreMessages {
switch {
case c.Initiator && m == MessagePatternS:
hs.ss.MixHash(hs.s.Public)
case c.Initiator && m == MessagePatternE:
hs.ss.MixHash(hs.e.Public)
case !c.Initiator && m == MessagePatternS:
hs.ss.MixHash(hs.rs)
case !c.Initiator && m == MessagePatternE:
hs.ss.MixHash(hs.re)
}
}
for _, m := range c.Pattern.ResponderPreMessages {
switch {
case !c.Initiator && m == MessagePatternS:
hs.ss.MixHash(hs.s.Public)
case !c.Initiator && m == MessagePatternE:
hs.ss.MixHash(hs.e.Public)
case c.Initiator && m == MessagePatternS:
hs.ss.MixHash(hs.rs)
case c.Initiator && m == MessagePatternE:
hs.ss.MixHash(hs.re)
}
}
return hs, nil
}
// WriteMessage appends a handshake message to out. The message will include the
// optional payload if provided. If the handshake is completed by the call, two
// CipherStates will be returned, one is used for encryption of messages to the
// remote peer, the other is used for decryption of messages from the remote
// peer. It is an error to call this method out of sync with the handshake
// pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
}
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
if err != nil {
return nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSE:
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
s.shouldWrite = false
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
}
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
var ErrShortMessage = errors.New("noise: message is too short")
// ReadMessage processes a received handshake message and appends the payload,
// if any to out. If the handshake is completed by the call, two CipherStates
// will be returned, one is used for encryption of messages to the remote peer,
// the other is used for decryption of messages from the remote peer. It is an
// error to call this method out of sync with the handshake pattern.
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
if s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
rsSet := false
s.ss.Checkpoint()
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE, MessagePatternS:
expected := s.ss.cs.DHLen()
if msg == MessagePatternS && s.ss.hasK {
expected += 16
}
if len(message) < expected {
return nil, nil, nil, ErrShortMessage
}
switch msg {
case MessagePatternE:
if cap(s.re) < s.ss.cs.DHLen() {
s.re = make([]byte, s.ss.cs.DHLen())
}
s.re = s.re[:s.ss.cs.DHLen()]
copy(s.re, message)
s.ss.MixHash(s.re)
if len(s.psk) > 0 {
s.ss.MixKey(s.re)
}
case MessagePatternS:
if len(s.rs) > 0 {
return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
}
s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
rsSet = true
}
if err != nil {
s.ss.Rollback()
if rsSet {
s.rs = nil
}
return nil, nil, nil, err
}
message = message[expected:]
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSE:
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
out, err = s.ss.DecryptAndHash(out, message)
if err != nil {
s.ss.Rollback()
if rsSet {
s.rs = nil
}
return nil, nil, nil, err
}
s.shouldWrite = true
s.msgIdx++
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
}
// ChannelBinding provides a value that uniquely identifies the session and can
// be used as a channel binding. It is an error to call this method before the
// handshake is complete.
func (s *HandshakeState) ChannelBinding() []byte {
return s.ss.h
}
// PeerStatic returns the static key provided by the remote peer during
// a handshake. It is an error to call this method if a handshake message
// containing a static key has not been read.
func (s *HandshakeState) PeerStatic() []byte {
return s.rs
}
// MessageIndex returns the current handshake message id
func (s *HandshakeState) MessageIndex() int {
return s.msgIdx
}
// PeerEphemeral returns the ephemeral key provided by the remote peer during
// a handshake. It is an error to call this method if a handshake message
// containing a static key has not been read.
func (s *HandshakeState) PeerEphemeral() []byte {
return s.re
}
// LocalEphemeral returns the local ephemeral key pair generated during
// a handshake.
func (s *HandshakeState) LocalEphemeral() DHKey {
return s.e
}