# nwaku/waku/v2/protocol/waku_noise/noise_handshake_processing.nim
# (586 lines, 23 KiB, Nim)

# Waku Noise Protocols for Waku Payload Encryption
## See spec for more details:
## https://github.com/vacp2p/rfc/tree/master/content/docs/rfcs/35
{.push raises: [Defect].}
import std/[oids, options, strutils, tables]
import chronos
import chronicles
import bearssl
import stew/[results, endians2]
import nimcrypto/[utils, sha2, hmac]
import libp2p/errors
import libp2p/crypto/[chacha20poly1305, curve25519]
import ./noise_types
import ./noise
import ./noise_utils
# Tag the log output of this module with the "wakunoise" chronicles topic
logScope:
  topics = "wakunoise"
#################################################################
# Handshake Processing
#################################
## Utilities
#################################
# Tells whether the local user has to read or write the next handshake message.
# The answer depends on the direction of the message pattern (written in canonical,
# initiator-centric form: D_r is "->", D_l is "<-") and on whether the local user
# is the handshake initiator.
# Returns the tuple (reading, writing).
proc getReadingWritingState(hs: HandshakeState, direction: MessageDirection): (bool, bool) =
  # Both flags stay false unless one of the two known directions matches
  result = (false, false)
  if direction == D_r:
    # "->": the initiator (Alice) writes, the responder (Bob) reads
    result = (not hs.initiator, hs.initiator)
  elif direction == D_l:
    # "<-": the responder (Bob) writes, the initiator (Alice) reads
    result = (hs.initiator, not hs.initiator)
# Checks whether a pre-message is valid according to the Noise specification:
# every pre-message pattern may only be "e", "s" or "e, s", in either direction.
# An empty list of pre-message patterns is considered valid.
# http://www.noiseprotocol.org/noise.html#handshake-patterns
proc isValid(msg: seq[PreMessagePattern]): bool =
  let allowedPatterns: seq[PreMessagePattern] = @[
    PreMessagePattern(direction: D_r, tokens: @[T_s]),
    PreMessagePattern(direction: D_r, tokens: @[T_e]),
    PreMessagePattern(direction: D_r, tokens: @[T_e, T_s]),
    PreMessagePattern(direction: D_l, tokens: @[T_s]),
    PreMessagePattern(direction: D_l, tokens: @[T_e]),
    PreMessagePattern(direction: D_l, tokens: @[T_e, T_s])
  ]
  result = true
  # A single pattern outside the allow-list invalidates the whole pre-message
  for pattern in msg:
    if pattern notin allowedPatterns:
      return false
#################################
# Handshake messages processing procedures
#################################
# Processes the pre-message patterns of the handshake, if any.
# Pre-message public keys are consumed in order from inPreMessagePKs, one for every "e"/"s" token:
# - if the local user reads the token, the key is stored as remote ephemeral (re) or remote
#   static (rs) key and mixed into the handshake hash (MixHash);
# - if the local user writes the token, the key has to match the locally set ephemeral (e) or
#   static (s) key, whose public part is then mixed into the handshake hash (MixHash).
# In PSK handshakes, "e" tokens additionally trigger MixKey(e.public_key) (Noise spec, section 9.2).
# Raises NoiseMalformedHandshake if the pre-message is invalid or contains an unsupported token,
# NoiseHandshakeError if fewer public keys than "e"/"s" tokens are provided, and NoisePublicKeyError
# if a read key is flagged as encrypted or a written key does not match the local key pair.
proc processPreMessagePatternTokens(hs: var HandshakeState, inPreMessagePKs: seq[NoisePublicKey] = @[])
  {.raises: [Defect, NoiseMalformedHandshake, NoiseHandshakeError, NoisePublicKeyError].} =
  # If the handshake has no pre-message, there's nothing to do
  if hs.handshakePattern.preMessagePatterns == EmptyPreMessage:
    return
  # If not empty, we check that the pre-message is valid according to Noise specifications
  if not isValid(hs.handshakePattern.preMessagePatterns):
    raise newException(NoiseMalformedHandshake, "Invalid pre-message in handshake")
  # Copy of the input pre-message public keys: processed keys are deleted,
  # so the next key to process is always at index 0
  var preMessagePKs = inPreMessagePKs
  # Noise specification, section 9.2: in a PSK handshake, the MixHash(e.public_key) call
  # triggered by an "e" token is followed by MixKey(e.public_key). "s" tokens are not affected.
  let isPskHandshake = "psk" in hs.handshakePattern.name
  for messagePattern in hs.handshakePattern.preMessagePatterns:
    # We get if the user is reading or writing the current pre-message pattern
    let (reading, writing) = getReadingWritingState(hs, messagePattern.direction)
    for token in messagePattern.tokens:
      case token
      of T_e:
        # We expect an ephemeral key
        if preMessagePKs.len == 0:
          raise newException(NoiseHandshakeError, "Noise pre-message read e, expected a public key")
        let currPK = preMessagePKs[0]
        if reading:
          trace "noise pre-message read e"
          # Pre-message public keys are assumed to be unencrypted on users' end
          if currPK.flag == 0.uint8:
            # Sets re and calls MixHash(re.public_key)
            hs.re = intoCurve25519Key(currPK.pk)
            hs.ss.mixHash(hs.re)
          else:
            raise newException(NoisePublicKeyError, "Noise read e, incorrect encryption flag for pre-message public key")
        elif writing:
          trace "noise pre-message write e"
          # The key sent must correspond to the locally set ephemeral key; we call MixHash(e.public_key)
          if hs.e.publicKey == intoCurve25519Key(currPK.pk):
            hs.ss.mixHash(hs.e.publicKey)
          else:
            raise newException(NoisePublicKeyError, "Noise pre-message e key doesn't correspond to locally set e key pair")
        # In a PSK handshake, MixHash(e.public_key) is followed by MixKey(e.public_key)
        if isPskHandshake:
          hs.ss.mixKey(currPK.pk)
        # We delete the processed public key
        preMessagePKs.delete(0)
      of T_s:
        # We expect a static key
        if preMessagePKs.len == 0:
          raise newException(NoiseHandshakeError, "Noise pre-message read s, expected a public key")
        let currPK = preMessagePKs[0]
        if reading:
          trace "noise pre-message read s"
          # Pre-message public keys are assumed to be unencrypted on users' end
          if currPK.flag == 0.uint8:
            # Sets rs and calls MixHash(rs.public_key)
            hs.rs = intoCurve25519Key(currPK.pk)
            hs.ss.mixHash(hs.rs)
          else:
            raise newException(NoisePublicKeyError, "Noise read s, incorrect encryption flag for pre-message public key")
        elif writing:
          trace "noise pre-message write s"
          # The key sent must correspond to the locally set static key; we call MixHash(s.public_key)
          if hs.s.publicKey == intoCurve25519Key(currPK.pk):
            hs.ss.mixHash(hs.s.publicKey)
          else:
            raise newException(NoisePublicKeyError, "Noise pre-message s key doesn't correspond to locally set s key pair")
        # Unlike "e", a pre-message "s" token is never followed by MixKey, even in PSK handshakes
        # (consistent with the processing of "s" tokens in handshake message patterns)
        preMessagePKs.delete(0)
      else:
        raise newException(NoiseMalformedHandshake, "Invalid Token for pre-message pattern")
# Processes the transport message implicitly attached at the end of the current message pattern.
# If the local user reads this message pattern, the message is processed with decryptAndHash.
# If the local user writes it, the message is processed with encryptAndHash.
# Returns the processed payload, or an empty sequence if the user neither reads nor writes.
proc processMessagePatternPayload(hs: var HandshakeState, transportMessage: seq[byte]): seq[byte]
  {.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
  let currentDirection = hs.handshakePattern.messagePatterns[hs.msgPatternIdx].direction
  let (isReading, isWriting) = getReadingWritingState(hs, currentDirection)
  if isReading:
    result = hs.ss.decryptAndHash(transportMessage)
  elif isWriting:
    result = hs.ss.encryptAndHash(transportMessage)
# Processes the tokens of the current message pattern (selected by hs.msgPatternIdx)
# according to the current handshake state.
# - If the local user reads this message pattern, inputHandshakeMessage holds the received
#   public keys. They are consumed in order, one per "e"/"s" token, and the returned sequence is empty.
# - If the local user writes this message pattern, the returned sequence holds the public keys
#   (possibly encrypted) to send to the other party.
# "psk" and DH tokens ("ee", "es", "se", "ss") update the symmetric state the same way regardless
# of who reads or writes; only "es"/"se" depend on whether the local user is the initiator.
# Note: this proc always returns ok(...); all errors are signaled by raising exceptions.
proc processMessagePatternTokens(rng: var BrHmacDrbgContext, hs: var HandshakeState, inputHandshakeMessage: seq[NoisePublicKey] = @[]): Result[seq[NoisePublicKey], cstring]
  {.raises: [Defect, NoiseHandshakeError, NoiseMalformedHandshake, NoisePublicKeyError, NoiseDecryptTagError, NoiseNonceMaxError].} =
  # We retrieve current message pattern (direction + tokens) to process
  let
    messagePattern = hs.handshakePattern.messagePatterns[hs.msgPatternIdx]
    direction = messagePattern.direction
    tokens = messagePattern.tokens
  # We get if the user is reading or writing the input handshake message
  var (reading, writing) = getReadingWritingState(hs , direction)
  # Copy of the input handshake message, so that processed PKs can be deleted
  # and the next PK to process is always at index 0.
  # (Possibly) non-empty if reading
  var inHandshakeMessage = inputHandshakeMessage
  # The party's output public keys
  # (Possibly) non-empty if writing
  var outHandshakeMessage: seq[NoisePublicKey] = @[]
  # In currPK we store the currently processed public key from the handshake message
  var currPK: NoisePublicKey
  # We process each message pattern token
  for token in tokens:
    case token
    of T_e:
      # If user is reading the "e" token
      if reading:
        trace "noise read e"
        # We expect an ephemeral key, so we attempt to read it (next PK to process will always be at index 0 of inHandshakeMessage)
        if inHandshakeMessage.len > 0:
          currPK = inHandshakeMessage[0]
        else:
          raise newException(NoiseHandshakeError, "Noise read e, expected a public key")
        # We check if current key is encrypted or not (flag 0: unencrypted, flag 1: encrypted)
        # Note: by specification, ephemeral keys should always be unencrypted. But we support encrypted ones.
        if currPK.flag == 0.uint8:
          # Unencrypted Public Key
          # Sets re and calls MixHash(re.public_key).
          hs.re = intoCurve25519Key(currPK.pk)
          hs.ss.mixHash(hs.re)
        elif currPK.flag == 1.uint8:
          # Encrypted public key. This is out of specification: we call decryptAndHash for
          # encrypted ephemeral keys, similarly as happens for (encrypted) static keys.
          # The decrypted key is then stored in re.
          hs.re = intoCurve25519Key(hs.ss.decryptAndHash(currPK.pk))
        else:
          raise newException(NoisePublicKeyError, "Noise read e, incorrect encryption flag for public key")
        # Noise specification: section 9.2
        # In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
        # in a call to MixHash(e.public_key).
        # In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
        if "psk" in hs.handshakePattern.name:
          hs.ss.mixKey(hs.re)
        # We delete processed public key
        inHandshakeMessage.delete(0)
      # If user is writing the "e" token
      elif writing:
        trace "noise write e"
        # We generate a new ephemeral keypair
        hs.e = genKeyPair(rng)
        # We update the state: MixHash(e.public_key)
        hs.ss.mixHash(hs.e.publicKey)
        # Noise specification: section 9.2
        # In non-PSK handshakes, the "e" token in a pre-message pattern or message pattern always results
        # in a call to MixHash(e.public_key).
        # In a PSK handshake, all of these calls are followed by MixKey(e.public_key).
        if "psk" in hs.handshakePattern.name:
          hs.ss.mixKey(hs.e.publicKey)
        # We add the (unencrypted) ephemeral public key to the Waku payload
        outHandshakeMessage.add toNoisePublicKey(getPublicKey(hs.e))
    of T_s:
      # If user is reading the "s" token
      if reading:
        trace "noise read s"
        # We expect a static key, so we attempt to read it (next PK to process will always be at index 0 of inHandshakeMessage)
        if inHandshakeMessage.len > 0:
          currPK = inHandshakeMessage[0]
        else:
          raise newException(NoiseHandshakeError, "Noise read s, expected a public key")
        # We check if current key is encrypted or not (flag 0: unencrypted, flag 1: encrypted)
        if currPK.flag == 0.uint8:
          # Unencrypted Public Key
          # Sets rs and calls MixHash(rs.public_key).
          hs.rs = intoCurve25519Key(currPK.pk)
          hs.ss.mixHash(hs.rs)
        elif currPK.flag == 1.uint8:
          # Encrypted public key
          # Decrypts it with decryptAndHash and stores the result in rs.
          hs.rs = intoCurve25519Key(hs.ss.decryptAndHash(currPK.pk))
        else:
          raise newException(NoisePublicKeyError, "Noise read s, incorrect encryption flag for public key")
        # We delete processed public key
        inHandshakeMessage.delete(0)
      # If user is writing the "s" token
      elif writing:
        trace "noise write s"
        # If the local static key is not set (the handshake state was not properly initialized), we raise an error
        if hs.s == default(KeyPair):
          raise newException(NoisePublicKeyError, "Static key not set")
        # We encrypt the public part of the static key in case a key is set in the Cipher State
        # That is, encS may either be an encrypted or unencrypted static key.
        let encS = hs.ss.encryptAndHash(hs.s.publicKey)
        # We add the (encrypted) static public key to the Waku payload
        # Note that encS = (Enc(s) || tag) if encryption key is set, otherwise encS = s.
        # We distinguish these two cases by checking length of encryption and we set the proper encryption flag
        if encS.len > Curve25519Key.len:
          outHandshakeMessage.add NoisePublicKey(flag: 1, pk: encS)
        else:
          outHandshakeMessage.add NoisePublicKey(flag: 0, pk: encS)
    of T_psk:
      # "psk" token: processed the same way whether the user reads or writes
      trace "noise psk"
      # Calls MixKeyAndHash(psk)
      hs.ss.mixKeyAndHash(hs.psk)
    of T_ee:
      # "ee" token: processed the same way whether the user reads or writes
      trace "noise dh ee"
      # If local and/or remote ephemeral keys are not set, we raise an error
      # NOTE(review): "ee"/"es" raise NoisePublicKeyError while "se"/"ss" raise
      # NoiseMalformedHandshake for the same kind of failure; kept as-is for callers.
      if hs.e == default(KeyPair) or hs.re == default(Curve25519Key):
        raise newException(NoisePublicKeyError, "Local or remote ephemeral key not set")
      # Calls MixKey(DH(e, re)).
      hs.ss.mixKey(dh(hs.e.privateKey, hs.re))
    of T_es:
      # "es" token: processed the same way whether the user reads or writes
      trace "noise dh es"
      # We check if keys are correctly set.
      # If both present, we call MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder.
      if hs.initiator:
        if hs.e == default(KeyPair) or hs.rs == default(Curve25519Key):
          raise newException(NoisePublicKeyError, "Local or remote ephemeral/static key not set")
        hs.ss.mixKey(dh(hs.e.privateKey, hs.rs))
      else:
        if hs.re == default(Curve25519Key) or hs.s == default(KeyPair):
          raise newException(NoisePublicKeyError, "Local or remote ephemeral/static key not set")
        hs.ss.mixKey(dh(hs.s.privateKey, hs.re))
    of T_se:
      # "se" token: processed the same way whether the user reads or writes
      trace "noise dh se"
      # We check if keys are correctly set.
      # If both present, call MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder.
      if hs.initiator:
        if hs.s == default(KeyPair) or hs.re == default(Curve25519Key):
          raise newException(NoiseMalformedHandshake, "Local or remote ephemeral/static key not set")
        hs.ss.mixKey(dh(hs.s.privateKey, hs.re))
      else:
        if hs.rs == default(Curve25519Key) or hs.e == default(KeyPair):
          raise newException(NoiseMalformedHandshake, "Local or remote ephemeral/static key not set")
        hs.ss.mixKey(dh(hs.e.privateKey, hs.rs))
    of T_ss:
      # "ss" token: processed the same way whether the user reads or writes
      trace "noise dh ss"
      # If local and/or remote static keys are not set, we raise an error
      if hs.s == default(KeyPair) or hs.rs == default(Curve25519Key):
        raise newException(NoiseMalformedHandshake, "Local or remote static key not set")
      # Calls MixKey(DH(s, rs)).
      hs.ss.mixKey(dh(hs.s.privateKey, hs.rs))
  return ok(outHandshakeMessage)
#################################
## Procedures to progress handshakes between users
#################################
# Initializes a Handshake State for the given handshake pattern.
# Stores the local role (initiator), the optional local ephemeral/static key pairs and the
# pre-shared key, mixes the prologue into the handshake hash and then processes any
# pre-message pattern using preMessagePKs.
# Raises the errors of processPreMessagePatternTokens if the pre-message cannot be processed.
proc initialize*(hsPattern: HandshakePattern, ephemeralKey: KeyPair = default(KeyPair), staticKey: KeyPair = default(KeyPair), prologue: seq[byte] = @[], psk: seq[byte] = @[], preMessagePKs: seq[NoisePublicKey] = @[], initiator: bool = false): HandshakeState
  {.raises: [Defect, NoiseMalformedHandshake, NoiseHandshakeError, NoisePublicKeyError].} =
  result = HandshakeState.init(hsPattern)
  # Local role and keying material
  result.initiator = initiator
  result.e = ephemeralKey
  result.s = staticKey
  result.psk = psk
  # The handshake starts from the first message pattern
  result.msgPatternIdx = 0
  # The prologue is mixed into the handshake hash before any pre-message key
  result.ss.mixHash(prologue)
  processPreMessagePatternTokens(result, preMessagePKs)
# Advances the handshake by one step, i.e. processes the next message pattern.
# Each user in a handshake alternates writing and reading of handshake messages:
# - if the user writes the handshake message, the (possibly empty) transport message has to be
#   passed to transportMessage and readPayloadV2 can be left to its default value.
#   The returned HandshakeStepResult.payload2 holds the payload to send to the other party.
# - if the user reads the handshake message, the read payload v2 has to be passed to readPayloadV2
#   and transportMessage can be left to its default value.
#   The returned HandshakeStepResult.transportMessage holds the (decrypted) read transport message.
# If all message patterns were already processed, an empty HandshakeStepResult is returned.
# Raises NoiseMalformedHandshake when writing with a handshake pattern that has no protocol ID
# in PayloadV2ProtocolIDs, plus the errors raised while processing tokens and payload.
proc stepHandshake*(rng: var BrHmacDrbgContext, hs: var HandshakeState, readPayloadV2: PayloadV2 = default(PayloadV2), transportMessage: seq[byte] = @[]): Result[HandshakeStepResult, cstring]
  {.raises: [Defect, NoiseHandshakeError, NoiseMalformedHandshake, NoisePublicKeyError, NoiseDecryptTagError, NoiseNonceMaxError].} =
  var hsStepResult: HandshakeStepResult
  # If there are no more message patterns left for processing we return an empty HandshakeStepResult.
  # The comparison is done on int: computing uint8(len - 1) would raise a RangeDefect
  # for a handshake pattern without message patterns (len - 1 == -1).
  if hs.msgPatternIdx.int >= hs.handshakePattern.messagePatterns.len:
    debug "stepHandshake called more times than the number of message patterns present in handshake"
    return ok(hsStepResult)
  # We process the next handshake message pattern
  # We get if the user is reading or writing the input handshake message
  let direction = hs.handshakePattern.messagePatterns[hs.msgPatternIdx].direction
  let (reading, writing) = getReadingWritingState(hs, direction)
  # If we write an answer at this handshake step
  if writing:
    # We initialize a payload v2 and we set proper protocol ID (if supported).
    # Only a missing table entry is turned into NoiseMalformedHandshake; Defects propagate.
    try:
      hsStepResult.payload2.protocolId = PayloadV2ProtocolIDs[hs.handshakePattern.name]
    except KeyError:
      raise newException(NoiseMalformedHandshake, "Handshake Pattern not supported")
    # We set the handshake and transport message
    hsStepResult.payload2.handshakeMessage = processMessagePatternTokens(rng, hs).get()
    hsStepResult.payload2.transportMessage = processMessagePatternPayload(hs, transportMessage)
  # If we read an answer during this handshake step
  elif reading:
    # We process the read public keys and (eventually) decrypt the read transport message
    let
      readHandshakeMessage = readPayloadV2.handshakeMessage
      readTransportMessage = readPayloadV2.transportMessage
    # Since we only read, nothing meaningful (i.e. public keys) is returned
    discard processMessagePatternTokens(rng, hs, readHandshakeMessage)
    # We retrieve and store the (decrypted) received transport message
    hsStepResult.transportMessage = processMessagePatternPayload(hs, readTransportMessage)
  else:
    raise newException(NoiseHandshakeError, "Handshake Error: neither writing or reading user")
  # We increase the handshake state message pattern index to progress to next step
  hs.msgPatternIdx += 1
  return ok(hsStepResult)
# Finalizes the handshake: calls Split() on the symmetric state and assigns the two resulting
# Cipher States to the local user's outbound/inbound directions according to its role.
# The remote static key rs and the handshake hash h are also copied into the result.
proc finalizeHandshake*(hs: var HandshakeState): HandshakeResult =
  ## Noise specification, Section 5:
  ## Processing the final handshake message returns two CipherState objects,
  ## the first for encrypting transport messages from initiator to responder,
  ## and the second for messages in the other direction.
  let (initiatorToResponder, responderToInitiator) = hs.ss.split()
  if hs.initiator:
    # The initiator sends with the first Cipher State and receives with the second
    result.csOutbound = initiatorToResponder
    result.csInbound = responderToInitiator
  else:
    # The responder uses them the other way around
    result.csInbound = initiatorToResponder
    result.csOutbound = responderToInitiator
  # Optional fields: remote static key and handshake hash
  result.rs = hs.rs
  result.h = hs.ss.h
#################################
# After-handshake procedures
#################################
## Noise specification, Section 5:
## Transport messages are then encrypted and decrypted by calling EncryptWithAd()
## and DecryptWithAd() on the relevant CipherState with zero-length associated data.
## If DecryptWithAd() signals an error due to DECRYPT() failure, then the input message is discarded.
## The application may choose to delete the CipherState and terminate the session on such an error,
## or may continue to attempt communications. If EncryptWithAd() or DecryptWithAd() signal an error
## due to nonce exhaustion, then the application must delete the CipherState and terminate the session.
# Encrypts a transport message with the outbound Cipher State obtained from the handshake
# and wraps it in a PayloadV2 carrying protocol ID 0.
# Raises NoiseNonceMaxError, propagated from encryptWithAd.
proc writeMessage*(hsr: var HandshakeResult, transportMessage: seq[byte]): PayloadV2
  {.raises: [Defect, NoiseNonceMaxError].} =
  # According to 35/WAKU2-NOISE RFC, no handshake protocol information is sent when
  # exchanging messages after the handshake, i.e. the protocol ID is 0
  result.protocolId = 0.uint8
  # Zero-length associated data, as per Noise specification
  result.transportMessage = encryptWithAd(hsr.csOutbound, @[], transportMessage)
# Decrypts a received transport message with the inbound Cipher State obtained from the handshake,
# using zero-length associated data.
# Always returns ok(...). The plaintext is empty when the payload's protocol ID is not 0
# (not a post-handshake transport message) or when decryption fails the authentication tag check.
# Failed decryptions are only logged, not raised, since an attacker may flood the content topic
# on which messages are exchanged.
proc readMessage*(hsr: var HandshakeResult, readPayload2: PayloadV2): Result[seq[byte], cstring]
  {.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
  # According to 35/WAKU2-NOISE RFC, no handshake protocol information is sent when exchanging messages
  if readPayload2.protocolId != 0.uint8:
    return ok(newSeq[byte]())
  var plaintext: seq[byte]
  try:
    plaintext = decryptWithAd(hsr.csInbound, @[], readPayload2.transportMessage)
  except NoiseDecryptTagError:
    debug "A read message failed decryption. Returning empty message as plaintext."
    plaintext = @[]
  return ok(plaintext)