# Source: nwaku/waku/waku_noise/noise_handshake_processing.nim
# Last modified: 2024-03-16 00:08:47 +01:00
#
# 673 lines, 26 KiB, Nim

# Waku Noise Protocols for Waku Payload Encryption
## See spec for more details:
## https://github.com/vacp2p/rfc/tree/master/content/docs/rfcs/35
# Enable strict exception tracking for every proc in this module: each proc must declare
# (via a `raises` pragma) the exceptions it may raise.
# NOTE(review): older compilers (Nim < 1.4) get `Defect` listed explicitly, presumably because
# they track Defects in `raises` lists — confirm before dropping the legacy branch.
when (NimMajor, NimMinor) < (1, 4):
  {.push raises: [Defect].}
else:
  {.push raises: [].}
import std/[options, strutils, tables]
import chronos
import chronicles
import bearssl/rand
import stew/results
import libp2p/crypto/[chacha20poly1305, curve25519]
import ./noise_types
import ./noise
import ./noise_utils
# chronicles log scope: log statements emitted from this module are tagged with the "waku noise" topic
logScope:
  topics = "waku noise"
#################################################################
# Handshake Processing
#################################
## Utilities
#################################
# Returns a `(reading, writing)` tuple telling whether the user has to read or write
# the next handshake message, given the message direction and the user's role.
# The initiator (Alice) writes "->" (D_r) messages and reads "<-" (D_l) ones;
# the responder (Bob) does the opposite.
# If `direction` is neither D_r nor D_l, both flags are false.
proc getReadingWritingState(
    hs: HandshakeState, direction: MessageDirection
): (bool, bool) =
  if direction == D_r:
    # "->": Alice writes, Bob reads
    result = (not hs.initiator, hs.initiator)
  elif direction == D_l:
    # "<-": Bob writes, Alice reads
    result = (hs.initiator, not hs.initiator)
# Checks whether a pre-message is valid according to the Noise specification:
# every pre-message pattern must be exactly "e", "s" or "e, s", in either direction.
# An empty pre-message is valid.
# http://www.noiseprotocol.org/noise.html#handshake-patterns
proc isValid(msg: seq[PreMessagePattern]): bool =
  let allowedPatterns: seq[PreMessagePattern] =
    @[
      PreMessagePattern(direction: D_r, tokens: @[T_s]),
      PreMessagePattern(direction: D_r, tokens: @[T_e]),
      PreMessagePattern(direction: D_r, tokens: @[T_e, T_s]),
      PreMessagePattern(direction: D_l, tokens: @[T_s]),
      PreMessagePattern(direction: D_l, tokens: @[T_e]),
      PreMessagePattern(direction: D_l, tokens: @[T_e, T_s]),
    ]
  # Reject as soon as one pattern is not in the allowed list
  for candidate in msg:
    if candidate notin allowedPatterns:
      return false
  return true
#################################
# Handshake messages processing procedures
#################################
# Processes the pre-message patterns of the handshake (if any), mixing the pre-message public
# keys into the symmetric state.
#
# `inPreMessagePKs` holds the pre-message public keys, in the same order as the "e"/"s" tokens
# appear in `hs.handshakePattern.preMessagePatterns`:
# - for tokens the user reads, the key must be unencrypted (flag 0) and is stored as the remote
#   key (`hs.re` / `hs.rs`);
# - for tokens the user writes, the key must match the locally set key pair (`hs.e` / `hs.s`).
#
# Raises:
# - NoiseMalformedHandshake if the pre-message is invalid or contains a token other than "e"/"s"
# - NoiseHandshakeError if fewer public keys than pre-message tokens are provided
# - NoisePublicKeyError if a read key is flagged as encrypted, or a written key doesn't match
#   the local key pair
proc processPreMessagePatternTokens(
    hs: var HandshakeState, inPreMessagePKs: seq[NoisePublicKey] = @[]
) {.
    raises: [Defect, NoiseMalformedHandshake, NoiseHandshakeError, NoisePublicKeyError]
.} =
  var
    # Copy of the input pre-message public keys: processed keys are deleted from the front,
    # so the next key to process is always at index 0
    preMessagePKs = inPreMessagePKs
    # The currently processed pre-message public key
    currPK: NoisePublicKey

  # If there are no pre-message patterns, there's nothing to do
  if hs.handshakePattern.preMessagePatterns == EmptyPreMessage:
    return

  # Otherwise the pre-message has to be valid according to the Noise specification
  if not isValid(hs.handshakePattern.preMessagePatterns):
    raise newException(NoiseMalformedHandshake, "Invalid pre-message in handshake")

  for messagePattern in hs.handshakePattern.preMessagePatterns:
    let
      direction = messagePattern.direction
      tokens = messagePattern.tokens
    # Whether the user is reading or writing the current pre-message pattern
    let (reading, writing) = getReadingWritingState(hs, direction)

    for token in tokens:
      case token
      of T_e:
        # We expect an ephemeral key (the next one to process is at index 0)
        if preMessagePKs.len > 0:
          currPK = preMessagePKs[0]
        else:
          raise newException(
            NoiseHandshakeError, "Noise pre-message read e, expected a public key"
          )
        if reading:
          trace "noise pre-message read e"
          # Pre-message public keys are assumed to be unencrypted on users' end
          if currPK.flag == 0.uint8:
            # Sets re and calls MixHash(re.public_key).
            hs.re = intoCurve25519Key(currPK.pk)
            hs.ss.mixHash(hs.re)
          else:
            raise newException(
              NoisePublicKeyError,
              "Noise read e, incorrect encryption flag for pre-message public key",
            )
        elif writing:
          trace "noise pre-message write e"
          # The sent key must be the public part of the locally set ephemeral key pair;
          # then we call MixHash(e.public_key).
          if hs.e.publicKey == intoCurve25519Key(currPK.pk):
            hs.ss.mixHash(hs.e.publicKey)
          else:
            raise newException(
              NoisePublicKeyError,
              "Noise pre-message e key doesn't correspond to locally set e key pair",
            )
        # Noise specification, section 9.2:
        # In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always
        # results in a call to MixHash(e.public_key).
        # In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
        if "psk" in hs.handshakePattern.name:
          hs.ss.mixKey(currPK.pk)
        # We delete the processed public key
        preMessagePKs.delete(0)
      of T_s:
        # We expect a static key (the next one to process is at index 0)
        if preMessagePKs.len > 0:
          currPK = preMessagePKs[0]
        else:
          raise newException(
            NoiseHandshakeError, "Noise pre-message read s, expected a public key"
          )
        if reading:
          trace "noise pre-message read s"
          # Pre-message public keys are assumed to be unencrypted on users' end
          if currPK.flag == 0.uint8:
            # Sets rs and calls MixHash(rs.public_key).
            hs.rs = intoCurve25519Key(currPK.pk)
            hs.ss.mixHash(hs.rs)
          else:
            raise newException(
              NoisePublicKeyError,
              "Noise read s, incorrect encryption flag for pre-message public key",
            )
        elif writing:
          trace "noise pre-message write s"
          # The sent key must be the public part of the locally set static key pair;
          # then we call MixHash(s.public_key).
          if hs.s.publicKey == intoCurve25519Key(currPK.pk):
            hs.ss.mixHash(hs.s.publicKey)
          else:
            raise newException(
              NoisePublicKeyError,
              "Noise pre-message s key doesn't correspond to locally set s key pair",
            )
        # Noise specification, section 9.2: only "e" tokens are followed by MixKey in PSK
        # handshakes, so a static pre-message key is only mixed into the handshake hash (above).
        # This matches how "s" tokens are handled in processMessagePatternTokens.
        # We delete the processed public key
        preMessagePKs.delete(0)
      else:
        raise
          newException(NoiseMalformedHandshake, "Invalid Token for pre-message pattern")
# Encrypts (when writing) or decrypts (when reading) the transport message attached at the end of
# the current message pattern (the one at `hs.msgPatternIdx`).
# - writing: the message is padded to NoisePaddingBlockSize, then passed to EncryptAndHash;
# - reading: the message is passed to DecryptAndHash, then unpadded.
# `extraAd` is passed as extra additional data to encryption/decryption
# (useful to authenticate the message nametag).
# Returns an empty sequence if the user neither reads nor writes the current pattern.
proc processMessagePatternPayload(
    hs: var HandshakeState, transportMessage: seq[byte], extraAd: openArray[byte] = []
): seq[byte] {.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
  let currentDirection =
    hs.handshakePattern.messagePatterns[hs.msgPatternIdx].direction
  let (isReading, isWriting) = getReadingWritingState(hs, currentDirection)
  if isReading:
    let decrypted = hs.ss.decryptAndHash(transportMessage, extraAd)
    result = pkcs7_unpad(decrypted, NoisePaddingBlockSize)
  elif isWriting:
    let padded = pkcs7_pad(transportMessage, NoisePaddingBlockSize)
    result = hs.ss.encryptAndHash(padded, extraAd)
# Processes the tokens of the current message pattern (the one at `hs.msgPatternIdx`) and updates
# the handshake state accordingly.
#
# - If the user reads the current message pattern, `inputHandshakeMessage` must contain the
#   received public keys in the order their "e"/"s" tokens appear; they are stored in
#   `hs.re` / `hs.rs`.
# - If the user writes it, a new ephemeral key pair is generated with `rng` for the "e" token, and
#   the (possibly encrypted) public keys to send are returned.
# - The DH tokens ("ee", "es", "se", "ss") and "psk" are processed whatever the direction.
#
# Returns `ok` with the public keys to send (empty when reading). This proc never returns `err`;
# failures are raised instead:
# - NoiseHandshakeError if a read "e"/"s" token has no corresponding public key in the input
# - NoisePublicKeyError on an unknown encryption flag, on an unset local static key when writing
#   "s", or on unset keys for "ee"/"es"
# - NoiseMalformedHandshake on unset keys for "se"/"ss"
# - NoiseDecryptTagError / NoiseNonceMaxError, presumably from the symmetric state
#   encrypt/decrypt calls (defined elsewhere)
# NOTE(review): public keys left unprocessed in `inputHandshakeMessage` are silently ignored.
proc processMessagePatternTokens(
    rng: var rand.HmacDrbgContext,
    hs: var HandshakeState,
    inputHandshakeMessage: seq[NoisePublicKey] = @[],
): Result[seq[NoisePublicKey], cstring] {.
    raises: [
      Defect, NoiseHandshakeError, NoiseMalformedHandshake, NoisePublicKeyError,
      NoiseDecryptTagError, NoiseNonceMaxError,
    ]
.} =
  # We retrieve current message pattern (direction + tokens) to process
  let
    messagePattern = hs.handshakePattern.messagePatterns[hs.msgPatternIdx]
    direction = messagePattern.direction
    tokens = messagePattern.tokens
  # We get if the user is reading or writing the input handshake message
  var (reading, writing) = getReadingWritingState(hs, direction)
  # Copy of the handshake message, so that processed PKs can be deleted from the front
  # (the next PK to process is always at index 0). (Possibly) non-empty if reading
  var inHandshakeMessage = inputHandshakeMessage
  # The party's output public keys
  # (Possibly) non-empty if writing
  var outHandshakeMessage: seq[NoisePublicKey] = @[]
  # In currPK we store the currently processed public key from the handshake message
  var currPK: NoisePublicKey
  # We process each message pattern token
  for token in tokens:
    case token
    of T_e:
      # If user is reading the "e" token
      if reading:
        trace "noise read e"
        # We expect an ephemeral key, so we attempt to read it (next PK to process will always be at index 0 of inHandshakeMessage)
        if inHandshakeMessage.len > 0:
          currPK = inHandshakeMessage[0]
        else:
          raise newException(NoiseHandshakeError, "Noise read e, expected a public key")
        # We check if current key is encrypted or not
        # Note: by specification, ephemeral keys should always be unencrypted. But we support encrypted ones.
        if currPK.flag == 0.uint8:
          # Unencrypted Public Key
          # Sets re and calls MixHash(re.public_key).
          hs.re = intoCurve25519Key(currPK.pk)
          hs.ss.mixHash(hs.re)
        # The following is out of specification: we call decryptAndHash for encrypted ephemeral keys, similarly as happens for (encrypted) static keys
        elif currPK.flag == 1.uint8:
          # Encrypted public key
          # Decrypts re, sets re and calls MixHash(re.public_key).
          hs.re = intoCurve25519Key(hs.ss.decryptAndHash(currPK.pk))
        else:
          raise newException(
            NoisePublicKeyError,
            "Noise read e, incorrect encryption flag for public key",
          )
        # Noise specification: section 9.2
        # In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
        # in a call to MixHash(e.public_key).
        # In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
        if "psk" in hs.handshakePattern.name:
          hs.ss.mixKey(hs.re)
        # We delete processed public key
        inHandshakeMessage.delete(0)
      # If user is writing the "e" token
      elif writing:
        trace "noise write e"
        # We generate a new ephemeral keypair
        hs.e = genKeyPair(rng)
        # We update the state: MixHash(e.public_key)
        hs.ss.mixHash(hs.e.publicKey)
        # Noise specification: section 9.2
        # In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
        # in a call to MixHash(e.public_key).
        # In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
        if "psk" in hs.handshakePattern.name:
          hs.ss.mixKey(hs.e.publicKey)
        # We add the (unencrypted) ephemeral public key to the output handshake message
        outHandshakeMessage.add toNoisePublicKey(getPublicKey(hs.e))
    of T_s:
      # If user is reading the "s" token
      if reading:
        trace "noise read s"
        # We expect a static key, so we attempt to read it (next PK to process will always be at index 0 of inHandshakeMessage)
        if inHandshakeMessage.len > 0:
          currPK = inHandshakeMessage[0]
        else:
          raise newException(NoiseHandshakeError, "Noise read s, expected a public key")
        # We check if current key is encrypted or not
        if currPK.flag == 0.uint8:
          # Unencrypted Public Key
          # Sets rs and calls MixHash(rs.public_key).
          hs.rs = intoCurve25519Key(currPK.pk)
          hs.ss.mixHash(hs.rs)
        elif currPK.flag == 1.uint8:
          # Encrypted public key
          # Decrypts rs, sets rs and calls MixHash(rs.public_key).
          hs.rs = intoCurve25519Key(hs.ss.decryptAndHash(currPK.pk))
        else:
          raise newException(
            NoisePublicKeyError,
            "Noise read s, incorrect encryption flag for public key",
          )
        # We delete processed public key
        inHandshakeMessage.delete(0)
      # If user is writing the "s" token
      elif writing:
        trace "noise write s"
        # If the local static key is not set (the handshake state was not properly initialized), we raise an error
        if isDefault(hs.s):
          raise newException(NoisePublicKeyError, "Static key not set")
        # We encrypt the public part of the static key in case a key is set in the Cipher State
        # That is, encS may either be an encrypted or unencrypted static key.
        let encS = hs.ss.encryptAndHash(hs.s.publicKey)
        # We add the (encrypted) static public key to the output handshake message
        # Note that encS = (Enc(s) || tag) if encryption key is set, otherwise encS = s.
        # We distinguish these two cases by checking length of encryption and we set the proper encryption flag
        if encS.len > Curve25519Key.len:
          outHandshakeMessage.add NoisePublicKey(flag: 1, pk: encS)
        else:
          outHandshakeMessage.add NoisePublicKey(flag: 0, pk: encS)
    of T_psk:
      # "psk" token (processed whether reading or writing)
      trace "noise psk"
      # Calls MixKeyAndHash(psk)
      hs.ss.mixKeyAndHash(hs.psk)
    of T_ee:
      # "ee" token (processed whether reading or writing)
      trace "noise dh ee"
      # If local and/or remote ephemeral keys are not set, we raise an error
      if isDefault(hs.e) or isDefault(hs.re):
        raise newException(NoisePublicKeyError, "Local or remote ephemeral key not set")
      # Calls MixKey(DH(e, re)).
      hs.ss.mixKey(dh(hs.e.privateKey, hs.re))
    of T_es:
      # "es" token (processed whether reading or writing)
      trace "noise dh es"
      # We check if keys are correctly set.
      # If both present, we call MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder.
      if hs.initiator:
        if isDefault(hs.e) or isDefault(hs.rs):
          raise newException(
            NoisePublicKeyError, "Local or remote ephemeral/static key not set"
          )
        hs.ss.mixKey(dh(hs.e.privateKey, hs.rs))
      else:
        if isDefault(hs.re) or isDefault(hs.s):
          raise newException(
            NoisePublicKeyError, "Local or remote ephemeral/static key not set"
          )
        hs.ss.mixKey(dh(hs.s.privateKey, hs.re))
    of T_se:
      # "se" token (processed whether reading or writing)
      trace "noise dh se"
      # We check if keys are correctly set.
      # If both present, call MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder.
      if hs.initiator:
        if isDefault(hs.s) or isDefault(hs.re):
          raise newException(
            NoiseMalformedHandshake, "Local or remote ephemeral/static key not set"
          )
        hs.ss.mixKey(dh(hs.s.privateKey, hs.re))
      else:
        if isDefault(hs.rs) or isDefault(hs.e):
          raise newException(
            NoiseMalformedHandshake, "Local or remote ephemeral/static key not set"
          )
        hs.ss.mixKey(dh(hs.e.privateKey, hs.rs))
    of T_ss:
      # "ss" token (processed whether reading or writing)
      trace "noise dh ss"
      # If local and/or remote static keys are not set, we raise an error
      if isDefault(hs.s) or isDefault(hs.rs):
        raise
          newException(NoiseMalformedHandshake, "Local or remote static key not set")
      # Calls MixKey(DH(s, rs)).
      hs.ss.mixKey(dh(hs.s.privateKey, hs.rs))
  return ok(outHandshakeMessage)
#################################
## Procedures to progress handshakes between users
#################################
# Creates and initializes a Handshake State for the handshake pattern `hsPattern`.
#
# The prologue is mixed into the handshake hash first. The given ephemeral/static key pairs and
# the pre-shared key are then stored, and the message pattern index is reset to 0. Finally, the
# pre-message public keys `preMessagePKs` (if any) are processed according to the pattern's
# pre-message.
# `initiator` tells whether the local user starts the handshake.
#
# Raises NoiseMalformedHandshake, NoiseHandshakeError or NoisePublicKeyError when the pre-message
# or the pre-message public keys are invalid (see processPreMessagePatternTokens).
proc initialize*(
    hsPattern: HandshakePattern,
    ephemeralKey: KeyPair = default(KeyPair),
    staticKey: KeyPair = default(KeyPair),
    prologue: seq[byte] = @[],
    psk: seq[byte] = @[],
    preMessagePKs: seq[NoisePublicKey] = @[],
    initiator: bool = false,
): HandshakeState {.
    raises: [Defect, NoiseMalformedHandshake, NoiseHandshakeError, NoisePublicKeyError]
.} =
  result = HandshakeState.init(hsPattern)
  result.ss.mixHash(prologue)
  # Local key material and role
  result.e = ephemeralKey
  result.s = staticKey
  result.psk = psk
  result.msgPatternIdx = 0
  result.initiator = initiator
  # Pre-message keys are mixed in last, once the role and local keys are known
  processPreMessagePatternTokens(result, preMessagePKs)
# Advances the handshake by one step, i.e. processes the next message pattern.
#
# Each user in a handshake alternates writing and reading of handshake messages:
# - when writing, pass the transport message to send (may be empty) in `transportMessage` and
#   the message nametag to attach in `messageNametag`; `readPayloadV2` can be left to its default
#   value. The returned `payload2` holds the payload to send.
# - when reading, pass the received payload in `readPayloadV2` and the expected message nametag
#   in `messageNametag`; `transportMessage` can be left to its default value. The returned
#   `transportMessage` holds the decrypted transport message.
# The input nametag is converted with toMessageNametag (empty input nametags are converted to
# all-0 MessageNametagLength bytes arrays).
#
# When reading, if the payload's nametag differs from the expected one, NoiseMessageNametagError
# is raised and no decryption is attempted.
# Once all message patterns have been processed, further calls return an empty
# HandshakeStepResult.
proc stepHandshake*(
    rng: var rand.HmacDrbgContext,
    hs: var HandshakeState,
    readPayloadV2: PayloadV2 = default(PayloadV2),
    transportMessage: seq[byte] = @[],
    messageNametag: openArray[byte] = [],
): Result[HandshakeStepResult, cstring] {.
    raises: [
      Defect, NoiseHandshakeError, NoiseMessageNametagError, NoiseMalformedHandshake,
      NoisePublicKeyError, NoiseDecryptTagError, NoiseNonceMaxError,
    ]
.} =
  var hsStepResult: HandshakeStepResult

  # If there are no more message patterns left for processing we return an empty result.
  # Compare as int: converting `len - 1` to uint8 raises a RangeDefect when the pattern has no
  # message patterns (len - 1 == -1) and overflows for patterns longer than 256 messages.
  if int(hs.msgPatternIdx) >= hs.handshakePattern.messagePatterns.len:
    debug "stepHandshake called more times than the number of message patterns present in handshake"
    return ok(hsStepResult)

  # We process the next handshake message pattern:
  # we get if the user is reading or writing the input handshake message
  let direction = hs.handshakePattern.messagePatterns[hs.msgPatternIdx].direction
  let (reading, writing) = getReadingWritingState(hs, direction)

  if writing:
    # We set the protocol ID of the handshake pattern, if supported
    try:
      hsStepResult.payload2.protocolId = PayloadV2ProtocolIDs[hs.handshakePattern.name]
    except CatchableError:
      raise newException(NoiseMalformedHandshake, "Handshake Pattern not supported")
    # We set the messageNametag and the handshake and transport messages
    hsStepResult.payload2.messageNametag = toMessageNametag(messageNametag)
    hsStepResult.payload2.handshakeMessage = processMessagePatternTokens(rng, hs).get()
    # We write the payload by passing the messageNametag as extra additional data
    hsStepResult.payload2.transportMessage = processMessagePatternPayload(
      hs, transportMessage, extraAd = hsStepResult.payload2.messageNametag
    )
  elif reading:
    # If the read message nametag doesn't match the expected input one we raise an error
    if readPayloadV2.messageNametag != toMessageNametag(messageNametag):
      raise newException(
        NoiseMessageNametagError,
        "The message nametag of the read message doesn't match the expected one",
      )
    # We process the read public keys and (eventually) decrypt the read transport message
    let
      readHandshakeMessage = readPayloadV2.handshakeMessage
      readTransportMessage = readPayloadV2.transportMessage
    # Since we only read, nothing meaningful (i.e. public keys) is returned
    discard processMessagePatternTokens(rng, hs, readHandshakeMessage)
    # We retrieve and store the (decrypted) received transport message by passing the
    # messageNametag as extra additional data
    hsStepResult.transportMessage = processMessagePatternPayload(
      hs, readTransportMessage, extraAd = readPayloadV2.messageNametag
    )
  else:
    raise newException(
      NoiseHandshakeError, "Handshake Error: neither writing or reading user"
    )

  # We increase the handshake state message pattern index to progress to next step
  hs.msgPatternIdx += 1
  return ok(hsStepResult)
# Finalizes the handshake: calls Split() and assigns the inbound/outbound Cipher States and the
# message nametag secrets according to the user's role (initiator or responder).
# The message nametag buffers are initialized from those secrets. The remote static key `rs` and
# the final handshake hash `h` are copied into the result.
proc finalizeHandshake*(hs: var HandshakeState): HandshakeResult =
  ## Noise specification, Section 5:
  ## Processing the final handshake message returns two CipherState objects,
  ## the first for encrypting transport messages from initiator to responder,
  ## and the second for messages in the other direction.
  let (cs1, cs2) = hs.ss.split()
  # Optional: secrets for the message nametags derivation
  let (nms1, nms2) = genMessageNametagSecrets(hs)

  # The responder uses the same states as the initiator, with directions swapped
  let (outboundCs, inboundCs, inboundSecret, outboundSecret) =
    if hs.initiator:
      (cs1, cs2, nms1, nms2)
    else:
      (cs2, cs1, nms2, nms1)

  result.csOutbound = outboundCs
  result.csInbound = inboundCs
  result.nametagsInbound.secret = some(inboundSecret)
  result.nametagsOutbound.secret = some(outboundSecret)

  # Buffers are initialized only once their secrets are set
  initNametagsBuffer(result.nametagsInbound)
  initNametagsBuffer(result.nametagsOutbound)

  # Optional fields: remote static key and handshake hash
  result.rs = hs.rs
  result.h = hs.ss.h
#################################
# After-handshake procedures
#################################
## Noise specification, Section 5:
## Transport messages are then encrypted and decrypted by calling EncryptWithAd()
## and DecryptWithAd() on the relevant CipherState with zero-length associated data.
## If DecryptWithAd() signals an error due to DECRYPT() failure, then the input message is discarded.
## The application may choose to delete the CipherState and terminate the session on such an error,
## or may continue to attempt communications. If EncryptWithAd() or DecryptWithAd() signal an error
## due to nonce exhaustion, then the application must delete the CipherState and terminate the session.
# Encrypts a transport message with the outbound Cipher State obtained from the handshake.
#
# The next nametag is popped from `outboundMessageNametagBuffer`, attached to the payload and used
# as associated data for the encryption. The transport message is padded to
# NoisePaddingBlockSize before encryption. The protocol ID is set to 0.
# Raises NoiseNonceMaxError, presumably when the outbound Cipher State nonce is exhausted.
proc writeMessage*(
    hsr: var HandshakeResult,
    transportMessage: seq[byte],
    outboundMessageNametagBuffer: var MessageNametagBuffer,
): PayloadV2 {.raises: [Defect, NoiseNonceMaxError].} =
  # Consume the next outbound nametag
  result.messageNametag = pop(outboundMessageNametagBuffer)
  # According to 35/WAKU2-NOISE RFC, no Handshake protocol information is sent when exchanging
  # messages: this corresponds to a protocol-id equal to 0
  result.protocolId = 0.uint8
  let padded = pkcs7_pad(transportMessage, NoisePaddingBlockSize)
  # The message nametag (not zero-length data) is used as associated data
  let nametagAd = @(result.messageNametag)
  result.transportMessage =
    encryptWithAd(hsr.csOutbound, ad = nametagAd, plaintext = padded)
# Decrypts a transport message received after the handshake, using the inbound Cipher State.
#
# The payload's message nametag is first checked against `inboundMessageNametagBuffer`. Decryption
# (with the nametag as associated data) is then attempted only for payloads with protocol ID 0:
# according to 35/WAKU2-NOISE RFC, no Handshake protocol information is sent when exchanging
# messages. For other protocol IDs an empty message is returned.
# On successful decryption the message is unpadded and the first nametag of the buffer is deleted.
# Messages failing decryption are discarded: an empty message is returned and the buffer is
# left untouched.
#
# Raises NoiseMessageNametagError if the nametag check does not succeed.
# NOTE(review): checkNametag presumably raises NoiseMessageNametagError / NoiseSomeMessagesWereLost
# itself on mismatch or lost messages; verify in noise_utils.
proc readMessage*(
    hsr: var HandshakeResult,
    readPayload2: PayloadV2,
    inboundMessageNametagBuffer: var MessageNametagBuffer,
): Result[seq[byte], cstring] {.
    raises: [
      Defect, NoiseDecryptTagError, NoiseMessageNametagError, NoiseNonceMaxError,
      NoiseSomeMessagesWereLost,
    ]
.} =
  # The output decrypted message
  var message: seq[byte]

  # If the message nametag does not correspond to the nametag expected in the inbound message
  # nametag buffer, an error is raised (to be handled externally, i.e. re-request lost messages,
  # discard, etc.).
  # An explicit check is used instead of `assert`: assertions are compiled out with -d:danger and
  # would otherwise surface as an AssertionDefect rather than the declared catchable error.
  if not checkNametag(readPayload2.messageNametag, inboundMessageNametagBuffer).isOk:
    raise newException(
      NoiseMessageNametagError,
      "The message nametag of the read message doesn't match the expected one",
    )

  # At this point the messageNametag matches the expected nametag.
  if readPayload2.protocolId == 0.uint8:
    # On application level we decide to discard messages which fail decryption, without raising
    # an error
    try:
      # Decryption is done with messageNametag as associated data
      let paddedMessage = decryptWithAd(
        hsr.csInbound,
        ad = @(readPayload2.messageNametag),
        ciphertext = readPayload2.transportMessage,
      )
      # We unpad the decrypted message
      message = pkcs7_unpad(paddedMessage, NoisePaddingBlockSize)
      # The message successfully decrypted: we delete the first element of the inbound Message
      # Nametag Buffer
      delete(inboundMessageNametagBuffer, 1)
    except NoiseDecryptTagError:
      debug "A read message failed decryption. Returning empty message as plaintext."
      message = @[]
  return ok(message)