mirror of
https://github.com/logos-messaging/go-noise.git
synced 2026-01-02 04:53:10 +00:00
544 lines
18 KiB
Go
544 lines
18 KiB
Go
package noise
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"strings"
|
|
)
|
|
|
|
// NoisePaddingBlockSize is the block size used to pad handshake and transport
// message payloads before encryption.
const (
	NoisePaddingBlockSize = 248
)
|
|
|
|
// The Handshake State as in https://noiseprotocol.org/noise.html#the-handshakestate-object
|
|
// Contains
|
|
// - the local and remote ephemeral/static keys e,s,re,rs (if any)
|
|
// - the initiator flag (true if the user creating the state is the handshake initiator, false otherwise)
|
|
// - the handshakePattern (containing the handshake protocol name, and (pre)message patterns)
|
|
// This object is further extended from specifications by storing:
|
|
// - a message pattern index msgPatternIdx indicating the next handshake message pattern to process
|
|
// - the user's preshared psk, if any
|
|
type HandshakeState struct {
|
|
s Keypair
|
|
e Keypair
|
|
rs []byte
|
|
re []byte
|
|
ss *SymmetricState
|
|
initiator bool
|
|
handshakePattern HandshakePattern
|
|
msgPatternIdx int
|
|
psk []byte
|
|
}
|
|
|
|
func NewHandshakeState(hsPattern HandshakePattern, psk []byte) *HandshakeState {
|
|
return &HandshakeState{
|
|
// By default the Handshake State initiator flag is set to false
|
|
// Will be set to true when the user associated to the handshake state starts an handshake
|
|
initiator: false,
|
|
handshakePattern: hsPattern,
|
|
psk: psk,
|
|
ss: NewSymmetricState(hsPattern),
|
|
msgPatternIdx: 0,
|
|
}
|
|
}
|
|
|
|
func (h *HandshakeState) Equals(b HandshakeState) bool {
|
|
if !bytes.Equal(h.s.Private, b.s.Private) {
|
|
return false
|
|
}
|
|
if !bytes.Equal(h.s.Public, b.s.Public) {
|
|
return false
|
|
}
|
|
if !bytes.Equal(h.e.Private, b.e.Private) {
|
|
return false
|
|
}
|
|
if !bytes.Equal(h.e.Public, b.e.Public) {
|
|
return false
|
|
}
|
|
|
|
if !bytes.Equal(h.rs, b.rs) {
|
|
return false
|
|
}
|
|
|
|
if !bytes.Equal(h.re, b.re) {
|
|
return false
|
|
}
|
|
|
|
if !h.ss.Equals(b.ss) {
|
|
return false
|
|
}
|
|
|
|
if h.initiator != b.initiator {
|
|
return false
|
|
}
|
|
|
|
if !h.handshakePattern.Equals(b.handshakePattern) {
|
|
return false
|
|
}
|
|
|
|
if h.msgPatternIdx != b.msgPatternIdx {
|
|
return false
|
|
}
|
|
|
|
if !bytes.Equal(h.psk, b.psk) {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (h *HandshakeState) genMessageNametagSecrets() (nms1 []byte, nms2 []byte) {
|
|
keyLen := h.handshakePattern.hashFn().Size()
|
|
output := getHKDF(h.handshakePattern.hashFn, h.ss.h, []byte{}, keyLen*2)
|
|
nms1 = output[:keyLen]
|
|
nms2 = output[keyLen:]
|
|
return
|
|
}
|
|
|
|
// Uses the cryptographic information stored in the input handshake state to generate a random message nametag
|
|
// In current implementation the messageNametag = HKDF(handshake hash value), but other derivation mechanisms can be implemented
|
|
func (h *HandshakeState) MessageNametag() MessageNametag {
|
|
output := getHKDF(h.handshakePattern.hashFn, h.ss.h, []byte{}, MessageNametagLength)
|
|
return BytesToMessageNametag(output)
|
|
}
|
|
|
|
// Handshake Processing
|
|
|
|
// Based on the message handshake direction and if the user is or not the initiator, returns a boolean tuple telling if the user
|
|
// has to read or write the next handshake message
|
|
func (h *HandshakeState) getReadingWritingState(direction MessageDirection) (reading bool, writing bool) {
|
|
if h.initiator && direction == Right {
|
|
// I'm Alice and direction is ->
|
|
writing = true
|
|
} else if h.initiator && direction == Left {
|
|
// I'm Alice and direction is <-
|
|
reading = true
|
|
} else if !h.initiator && direction == Right {
|
|
// I'm Bob and direction is ->
|
|
reading = true
|
|
} else if !h.initiator && direction == Left {
|
|
// I'm Bob and direction is <-
|
|
writing = true
|
|
}
|
|
return reading, writing
|
|
}
|
|
|
|
// Checks if a pre-message is valid according to Noise specifications
|
|
// http://www.noiseprotocol.org/noise.html#handshake-patterns
|
|
func (h *HandshakeState) isValid(msg []PreMessagePattern) bool {
|
|
// Non-empty pre-messages can only have patterns "e", "s", "e,s" in each direction
|
|
allowedPatterns := []PreMessagePattern{
|
|
NewPreMessagePattern(Right, []NoiseTokens{S}),
|
|
NewPreMessagePattern(Right, []NoiseTokens{E}),
|
|
NewPreMessagePattern(Right, []NoiseTokens{E, S}),
|
|
NewPreMessagePattern(Left, []NoiseTokens{S}),
|
|
NewPreMessagePattern(Left, []NoiseTokens{E}),
|
|
NewPreMessagePattern(Left, []NoiseTokens{E, S}),
|
|
}
|
|
|
|
// We check if pre message patterns are allowed
|
|
for _, p := range msg {
|
|
found := false
|
|
for _, allowed := range allowedPatterns {
|
|
if allowed.Equals(p) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// Handshake messages processing procedures
|
|
|
|
// Processes pre-message patterns
|
|
func (h *HandshakeState) processPreMessagePatternTokens(inPreMessagePKs []*NoisePublicKey) error {
|
|
// I make a copy of the input pre-message public keys, so that I can easily delete processed ones without using iterators/counters
|
|
preMessagePKs := append([]*NoisePublicKey(nil), inPreMessagePKs...)
|
|
|
|
// Here we store currently processed pre message public key
|
|
var currPK *NoisePublicKey
|
|
|
|
// We retrieve the pre-message patterns to process, if any
|
|
// If none, there's nothing to do
|
|
if len(h.handshakePattern.premessagePatterns) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// If not empty, we check that pre-message is valid according to Noise specifications
|
|
if !h.isValid(h.handshakePattern.premessagePatterns) {
|
|
return errors.New("invalid pre-message in handshake")
|
|
}
|
|
|
|
// We iterate over each pattern contained in the pre-message
|
|
for _, messagePattern := range h.handshakePattern.premessagePatterns {
|
|
direction := messagePattern.direction
|
|
tokens := messagePattern.tokens
|
|
|
|
// We get if the user is reading or writing the current pre-message pattern
|
|
reading, writing := h.getReadingWritingState(direction)
|
|
|
|
// We process each message pattern token
|
|
for _, token := range tokens {
|
|
// We process the pattern token
|
|
switch token {
|
|
case E:
|
|
// We expect an ephemeral key, so we attempt to read it (next PK to process will always be at index 0 of preMessagePKs)
|
|
if len(preMessagePKs) > 0 {
|
|
currPK = preMessagePKs[0]
|
|
} else {
|
|
return errors.New("noise pre-message read e, expected a public key")
|
|
}
|
|
|
|
// If user is reading the "e" token
|
|
if reading {
|
|
// We check if current key is encrypted or not. We assume pre-message public keys are all unencrypted on users' end
|
|
if currPK.Flag == 0 {
|
|
// Sets re and calls MixHash(re.public_key).
|
|
h.re = currPK.Public
|
|
h.ss.mixHash(h.re)
|
|
} else {
|
|
return errors.New("noise read e, incorrect encryption flag for pre-message public key")
|
|
}
|
|
// If user is writing the "e" token
|
|
} else if writing {
|
|
// When writing, the user is sending a public key,
|
|
// We check that the public part corresponds to the set local key and we call MixHash(e.public_key).
|
|
if bytes.Equal(h.e.Public, currPK.Public) {
|
|
h.ss.mixHash(h.e.Public)
|
|
} else {
|
|
return errors.New("noise pre-message e key doesn't correspond to locally set e key pair")
|
|
}
|
|
}
|
|
|
|
// Noise specification: section 9.2
|
|
// In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
|
|
// in a call to MixHash(e.public_key).
|
|
// In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
|
|
if strings.Contains(h.handshakePattern.name, string(PSK)) {
|
|
h.ss.mixKey(currPK.Public)
|
|
}
|
|
|
|
// We delete processed public key
|
|
preMessagePKs = preMessagePKs[1:]
|
|
case S:
|
|
// We expect a static key, so we attempt to read it (next PK to process will always be at index of preMessagePKs)
|
|
if len(preMessagePKs) > 0 {
|
|
currPK = preMessagePKs[0]
|
|
} else {
|
|
return errors.New("noise pre-message read s, expected a public key")
|
|
}
|
|
|
|
// If user is reading the "s" token
|
|
if reading {
|
|
// We check if current key is encrypted or not. We assume pre-message public keys are all unencrypted on users' end
|
|
if currPK.Flag == 0 {
|
|
// Sets rs and calls MixHash(rs.public_key).
|
|
h.rs = currPK.Public
|
|
h.ss.mixHash(h.rs)
|
|
} else {
|
|
return errors.New("noise read s, incorrect encryption flag for pre-message public key")
|
|
}
|
|
|
|
// If user is writing the "s" token
|
|
} else if writing {
|
|
// If writing, it means that the user is sending a public key,
|
|
// We check that the public part corresponds to the set local key and we call MixHash(s.public_key).
|
|
if bytes.Equal(h.s.Public, currPK.Public) {
|
|
h.ss.mixHash(h.s.Public)
|
|
} else {
|
|
return errors.New("noise pre-message s key doesn't correspond to locally set s key pair")
|
|
}
|
|
}
|
|
|
|
// Noise specification: section 9.2
|
|
// In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
|
|
// in a call to MixHash(e.public_key).
|
|
// In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
|
|
if strings.Contains(h.handshakePattern.name, string(PSK)) {
|
|
h.ss.mixKey(currPK.Public)
|
|
}
|
|
|
|
// We delete processed public key
|
|
preMessagePKs = preMessagePKs[1:]
|
|
default:
|
|
return errors.New("invalid Token for pre-message pattern")
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// This procedure encrypts/decrypts the implicit payload attached at the end of every message pattern
|
|
// An optional extraAd to pass extra additional data in encryption/decryption can be set (useful to authenticate messageNametag)
|
|
func (h *HandshakeState) processMessagePatternPayload(transportMessage []byte, extraAd []byte) ([]byte, error) {
|
|
var payload []byte
|
|
var err error
|
|
|
|
// We retrieve current message pattern (direction + tokens) to process
|
|
direction := h.handshakePattern.messagePatterns[h.msgPatternIdx].direction
|
|
|
|
// We get if the user is reading or writing the input handshake message
|
|
reading, writing := h.getReadingWritingState(direction)
|
|
|
|
// We decrypt the transportMessage, if any
|
|
if reading {
|
|
payload, err = h.ss.decryptAndHash(transportMessage, extraAd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
payload, err = PKCS7_Unpad(payload, NoisePaddingBlockSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else if writing {
|
|
payload, err = PKCS7_Pad(transportMessage, NoisePaddingBlockSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
payload, err = h.ss.encryptAndHash(payload, extraAd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, errors.New("undefined state")
|
|
}
|
|
return payload, nil
|
|
}
|
|
|
|
// We process an input handshake message according to current handshake state and we return the next handshake step's handshake message
|
|
func (h *HandshakeState) processMessagePatternTokens(inputHandshakeMessage []*NoisePublicKey) ([]*NoisePublicKey, error) {
|
|
// We retrieve current message pattern (direction + tokens) to process
|
|
messagePattern := h.handshakePattern.messagePatterns[h.msgPatternIdx]
|
|
direction := messagePattern.direction
|
|
tokens := messagePattern.tokens
|
|
|
|
// We get if the user is reading or writing the input handshake message
|
|
reading, writing := h.getReadingWritingState(direction)
|
|
|
|
// I make a copy of the handshake message so that I can easily delete processed PKs without using iterators/counters
|
|
// (Possibly) non-empty if reading
|
|
inHandshakeMessage := append([]*NoisePublicKey(nil), inputHandshakeMessage...)
|
|
|
|
// The party's output public keys
|
|
// (Possibly) non-empty if writing
|
|
var outHandshakeMessage []*NoisePublicKey
|
|
|
|
// In currPK we store the currently processed public key from the handshake message
|
|
var currPK *NoisePublicKey
|
|
|
|
// We process each message pattern token
|
|
for _, token := range tokens {
|
|
switch token {
|
|
case E:
|
|
// If user is reading the "s" token
|
|
if reading {
|
|
// We expect an ephemeral key, so we attempt to read it (next PK to process will always be at index 0 of preMessagePKs)
|
|
if len(inHandshakeMessage) > 0 {
|
|
currPK = inHandshakeMessage[0]
|
|
} else {
|
|
return nil, errors.New("noise read e, expected a public key")
|
|
}
|
|
|
|
// We check if current key is encrypted or not
|
|
// Note: by specification, ephemeral keys should always be unencrypted. But we support encrypted ones.
|
|
if currPK.Flag == 0 {
|
|
// Unencrypted Public Key
|
|
// Sets re and calls MixHash(re.public_key).
|
|
h.re = currPK.Public
|
|
h.ss.mixHash(h.re)
|
|
|
|
// The following is out of specification: we call decryptAndHash for encrypted ephemeral keys, similarly as happens for (encrypted) static keys
|
|
} else if currPK.Flag == 1 {
|
|
// Encrypted public key
|
|
// Decrypts re, sets re and calls MixHash(re.public_key).
|
|
decRe, err := h.ss.decryptAndHash(currPK.Public, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.re = decRe
|
|
} else {
|
|
return nil, errors.New("noise read e, incorrect encryption flag for public key")
|
|
}
|
|
|
|
// Noise specification: section 9.2
|
|
// In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
|
|
// in a call to MixHash(e.public_key).
|
|
// In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
|
|
if strings.Contains(h.handshakePattern.name, string(PSK)) {
|
|
h.ss.mixKey(h.re)
|
|
}
|
|
|
|
// We delete processed public key
|
|
inHandshakeMessage = inHandshakeMessage[1:]
|
|
|
|
// If user is writing the "e" token
|
|
} else if writing {
|
|
// We generate a new ephemeral keypair
|
|
e, err := DH25519.GenerateKeypair()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.e = e
|
|
|
|
// We update the state
|
|
h.ss.mixHash(h.e.Public)
|
|
|
|
// Noise specification: section 9.2
|
|
// In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
|
|
// in a call to MixHash(e.public_key).
|
|
// In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
|
|
if strings.Contains(h.handshakePattern.name, string(PSK)) {
|
|
h.ss.mixKey(h.e.Public)
|
|
}
|
|
|
|
// We add the ephemeral public key to the Waku payload
|
|
outHandshakeMessage = append(outHandshakeMessage, byteToNoisePublicKey(h.handshakePattern.dhKey, h.e.Public))
|
|
}
|
|
case S:
|
|
// If user is reading the "s" token
|
|
if reading {
|
|
// We expect a static key, so we attempt to read it (next PK to process will always be at index 0 of preMessagePKs)
|
|
if len(inHandshakeMessage) > 0 {
|
|
currPK = inHandshakeMessage[0]
|
|
} else {
|
|
return nil, errors.New("noise read s, expected a public key")
|
|
}
|
|
|
|
// We check if current key is encrypted or not
|
|
if currPK.Flag == 0 {
|
|
// Unencrypted Public Key
|
|
// Sets re and calls MixHash(re.public_key).
|
|
h.rs = currPK.Public
|
|
h.ss.mixHash(h.rs)
|
|
} else if currPK.Flag == 1 {
|
|
// Encrypted public key
|
|
// Decrypts rs, sets rs and calls MixHash(rs.public_key).
|
|
decRS, err := h.ss.decryptAndHash(currPK.Public, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.rs = decRS
|
|
} else {
|
|
return nil, errors.New("noise read s, incorrect encryption flag for public key")
|
|
}
|
|
|
|
// We delete processed public key
|
|
inHandshakeMessage = inHandshakeMessage[1:]
|
|
|
|
// If user is writing the "s" token
|
|
} else if writing {
|
|
// If the local static key is not set (the handshake state was not properly initialized), we raise an error
|
|
if h.s.IsDefault() {
|
|
return nil, errors.New("static key not set")
|
|
}
|
|
|
|
// We encrypt the public part of the static key in case a key is set in the Cipher State
|
|
// That is, encS may either be an encrypted or unencrypted static key.
|
|
encS, err := h.ss.encryptAndHash(h.s.Public, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// We add the (encrypted) static public key to the Waku payload
|
|
// Note that encS = (Enc(s) || tag) if encryption key is set, otherwise encS = s.
|
|
// We distinguish these two cases by checking length of encryption and we set the proper encryption flag
|
|
if len(encS) > h.handshakePattern.dhKey.DHLen() {
|
|
outHandshakeMessage = append(outHandshakeMessage, byteToNoisePublicKey(h.handshakePattern.dhKey, encS))
|
|
} else {
|
|
outHandshakeMessage = append(outHandshakeMessage, byteToNoisePublicKey(h.handshakePattern.dhKey, encS))
|
|
}
|
|
}
|
|
case PSK:
|
|
// If user is reading the "psk" token
|
|
|
|
// Calls MixKeyAndHash(psk)
|
|
h.ss.mixKeyAndHash(h.psk)
|
|
case EE:
|
|
// If user is reading the "ee" token
|
|
|
|
// If local and/or remote ephemeral keys are not set, we raise an error
|
|
if h.e.IsDefault() || len(h.re) == 0 {
|
|
return nil, errors.New("local or remote ephemeral key not set")
|
|
}
|
|
|
|
// Calls MixKey(DH(e, re)).
|
|
k, err := h.handshakePattern.dhKey.DH(h.e.Private, h.re)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.ss.mixKey(k)
|
|
case ES:
|
|
// If user is reading the "es" token
|
|
|
|
// We check if keys are correctly set.
|
|
// If both present, we call MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder.
|
|
if h.initiator {
|
|
if h.e.IsDefault() || len(h.rs) == 0 {
|
|
return nil, errors.New("local or remote ephemeral/static key not set")
|
|
}
|
|
|
|
k, err := h.handshakePattern.dhKey.DH(h.e.Private, h.rs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.ss.mixKey(k)
|
|
} else {
|
|
if len(h.re) == 0 || h.s.IsDefault() {
|
|
return nil, errors.New("local or remote ephemeral/static key not set")
|
|
}
|
|
|
|
k, err := h.handshakePattern.dhKey.DH(h.s.Private, h.re)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.ss.mixKey(k)
|
|
}
|
|
case SE:
|
|
// If user is reading the "se" token
|
|
|
|
// We check if keys are correctly set.
|
|
// If both present, call MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder.
|
|
if h.initiator {
|
|
if h.s.IsDefault() || len(h.re) == 0 {
|
|
return nil, errors.New("local or remote ephemeral/static key not set")
|
|
}
|
|
|
|
k, err := h.handshakePattern.dhKey.DH(h.s.Private, h.re)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.ss.mixKey(k)
|
|
} else {
|
|
if len(h.rs) == 0 || h.e.IsDefault() {
|
|
return nil, errors.New("local or remote ephemeral/static key not set")
|
|
}
|
|
|
|
k, err := h.handshakePattern.dhKey.DH(h.e.Private, h.rs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.ss.mixKey(k)
|
|
}
|
|
case SS:
|
|
// If user is reading the "ss" token
|
|
|
|
// If local and/or remote static keys are not set, we raise an error
|
|
if h.s.IsDefault() || len(h.rs) == 0 {
|
|
return nil, errors.New("local or remote static key not set")
|
|
}
|
|
|
|
// Calls MixKey(DH(s, rs)).
|
|
k, err := h.handshakePattern.dhKey.DH(h.s.Private, h.rs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
h.ss.mixKey(k)
|
|
}
|
|
}
|
|
|
|
return outHandshakeMessage, nil
|
|
}
|