diff --git a/tests/v2/test_waku_noise.nim b/tests/v2/test_waku_noise.nim index 60e479ef0..7893ef8a6 100644 --- a/tests/v2/test_waku_noise.nim +++ b/tests/v2/test_waku_noise.nim @@ -3,11 +3,15 @@ import testutils/unittests, std/random, + std/tables, stew/byteutils, ../../waku/v2/node/waku_payload, ../../waku/v2/protocol/waku_noise/noise, ../../waku/v2/protocol/waku_message, - ../test_helpers + ../test_helpers, + libp2p/crypto/chacha20poly1305, + stew/endians2 + procSuite "Waku Noise": @@ -157,4 +161,255 @@ procSuite "Waku Noise": check: decoded.isOk() - payload2 == decoded.get() \ No newline at end of file + payload2 == decoded.get() + + test "Noise State Machine: Diffie-Hellman operation": + + #We generate random keypairs + let + aliceKey = genKeyPair(rng[]) + bobKey = genKeyPair(rng[]) + + # A Diffie-Hellman operation between Alice's private key and Bob's public key must be equal to + # a Diffie-hellman operation between Alice's public key and Bob's private key + let + dh1 = dh(getPrivateKey(aliceKey), getPublicKey(bobKey)) + dh2 = dh(getPrivateKey(bobKey), getPublicKey(aliceKey)) + + check: + dh1 == dh2 + + test "Noise State Machine: Cipher State primitives": + + # We generate a random Cipher State, associated data ad and plaintext + var + cipherState: CipherState = randomCipherState(rng[]) + nonce: uint64 = uint64(rand(0 .. int.high)) + ad: seq[byte] = randomSeqByte(rng[], rand(1..128)) + plaintext: seq[byte] = randomSeqByte(rng[], rand(1..128)) + + # We set the random nonce generated in the cipher state + setNonce(cipherState, nonce) + + # We perform encryption + var ciphertext: seq[byte] = encryptWithAd(cipherState, ad, plaintext) + + # After any encryption/decryption operation, the Cipher State's nonce increases by 1 + check: + getNonce(cipherState) == nonce + 1 + + # We set the nonce back to its original value for decryption + setNonce(cipherState, nonce) + + # We decrypt (using the original nonce) + var decrypted: seq[byte] = decryptWithAd(cipherState, ad, ciphertext) + + # We check if encryption and decryption are correct and that nonce correctly increased after decryption + check: + getNonce(cipherState) == nonce + 1 + plaintext == decrypted + + + # If a Cipher State has no key set, encryptWithAd should return the plaintext without increasing the nonce + setCipherStateKey(cipherState, EmptyKey) + nonce = getNonce(cipherState) + + plaintext = randomSeqByte(rng[], rand(1..128)) + ciphertext = encryptWithAd(cipherState, ad, plaintext) + + check: + ciphertext == plaintext + getNonce(cipherState) == nonce + + # If a Cipher State has no key set, decryptWithAd should return the ciphertext without increasing the nonce + setCipherStateKey(cipherState, EmptyKey) + nonce = getNonce(cipherState) + + # Note that we set ciphertext minimum length to 16 to not trigger checks on authentication tag length + ciphertext = randomSeqByte(rng[], rand(16..128)) + plaintext = decryptWithAd(cipherState, ad, ciphertext) + + check: + ciphertext == plaintext + getNonce(cipherState) == nonce + + # A Cipher State cannot have a nonce greater or equal 2^64-1 + # Note that NonceMax is uint64.high - 1 = 2^64-1-1 and that nonce is increased after each encryption and decryption operation + + # We generate a test Cipher State with nonce set to MaxNonce + cipherState = randomCipherState(rng[]) + setNonce(cipherState, NonceMax) + plaintext = randomSeqByte(rng[], rand(1..128)) + + # We test if encryption fails with a NoiseNonceMaxError error. Any subsequent encryption call over the Cipher State should fail similarly and leave the nonce unchanged + for _ in [1..5]: + expect NoiseNonceMaxError: + ciphertext = encryptWithAd(cipherState, ad, plaintext) + + check: + getNonce(cipherState) == NonceMax + 1 + + # We generate a test Cipher State + # Since nonce is increased after decryption as well, we need to generate a proper ciphertext in order to test MaxNonceError error handling + # We cannot call encryptWithAd to encrypt a plaintext using a nonce equal MaxNonce, since this will trigger a MaxNonceError. + # To perform such test, we then need to encrypt a test plaintext using directly ChaChaPoly primitive + cipherState = randomCipherState(rng[]) + setNonce(cipherState, NonceMax) + plaintext = randomSeqByte(rng[], rand(1..128)) + + # We perform encryption using the Cipher State key, NonceMax and ad + # By Noise specification the nonce is 8 bytes long out of the 12 bytes supported by ChaChaPoly, thus we copy the Little endian conversion of the nonce to a ChaChaPolyNonce + var + encNonce: ChaChaPolyNonce + authorizationTag: ChaChaPolyTag + encNonce[4..<12] = toBytesLE(NonceMax) + ChaChaPoly.encrypt(getKey(cipherState), encNonce, authorizationTag, plaintext, ad) + + # The output ciphertext is stored in the plaintext variable after ChaChaPoly.encrypt is called: we copy it along with the authorization tag. + ciphertext = @[] + ciphertext.add(plaintext) + ciphertext.add(authorizationTag) + + # At this point ciphertext is a proper encryption of the original plaintext obtained with nonce equal to NonceMax + # We can now test if decryption fails with a NoiseNonceMaxError error. Any subsequent decryption call over the Cipher State should fail similarly and leave the nonce unchanged + # Note that decryptWithAd doesn't fail in decrypting the ciphertext (otherwise a NoiseDecryptTagError would have been triggered) + for _ in [1..5]: + expect NoiseNonceMaxError: + plaintext = decryptWithAd(cipherState, ad, ciphertext) + + check: + getNonce(cipherState) == NonceMax + 1 + + test "Noise State Machine: Symmetric State primitives": + + # We select one supported handshake pattern and we initialize a symmetric state + var + hsPattern = NoiseHandshakePatterns["XX"] + symmetricState: SymmetricState = SymmetricState.init(hsPattern) + + # We get all the Symmetric State field + # cs : Cipher State + # ck : chaining key + # h : handshake hash + var + cs = getCipherState(symmetricState) + ck = getChainingKey(symmetricState) + h = getHandshakeHash(symmetricState) + + # When a Symmetric state is initialized, handshake hash and chaining key are (byte-wise) equal + check: + h.data.intoChaChaPolyKey == ck + + ######################################## + # mixHash + ######################################## + + # We generate a random byte sequence and execute a mixHash over it + mixHash(symmetricState, randomSeqByte(rng[], rand(1..128))) + + # mixHash changes only the handshake hash value of the Symmetric state + check: + cs == getCipherState(symmetricState) + ck == getChainingKey(symmetricState) + h != getHandshakeHash(symmetricState) + + # We update test values + h = getHandshakeHash(symmetricState) + + ######################################## + # mixKey + ######################################## + + # We generate random input key material and we execute mixKey + var inputKeyMaterial = randomChaChaPolyKey(rng[]) + mixKey(symmetricState, inputKeyMaterial) + + # mixKey changes the Symmetric State's chaining key and encryption key of the embedded Cipher State + # It further sets to 0 the nonce of the embedded Cipher State + check: + getKey(cs) != getKey(getCipherState(symmetricState)) + getNonce(getCipherState(symmetricState)) == 0.uint64 + cs != getCipherState(symmetricState) + ck != getChainingKey(symmetricState) + h == getHandshakeHash(symmetricState) + + # We update test values + cs = getCipherState(symmetricState) + ck = getChainingKey(symmetricState) + + ######################################## + # mixKeyAndHash + ######################################## + + # We generate random input key material and we execute mixKeyAndHash + inputKeyMaterial = randomChaChaPolyKey(rng[]) + mixKeyAndHash(symmetricState, inputKeyMaterial) + + # mixKeyAndHash executes a mixKey and a mixHash using the input key material + # All Symmetric State's fields are updated + check: + cs != getCipherState(symmetricState) + ck != getChainingKey(symmetricState) + h != getHandshakeHash(symmetricState) + + # We update test values + cs = getCipherState(symmetricState) + ck = getChainingKey(symmetricState) + h = getHandshakeHash(symmetricState) + + ######################################## + # encryptAndHash and decryptAndHash + ######################################## + + # We store the initial symmetricState in order to correctly perform decryption + var initialSymmetricState = symmetricState + + # We generate random plaintext and we execute encryptAndHash + var plaintext = randomChaChaPolyKey(rng[]) + var nonce = getNonce(getCipherState(symmetricState)) + var ciphertext = encryptAndHash(symmetricState, plaintext) + + # encryptAndHash combines encryptWithAd and mixHash over the ciphertext (encryption increases the nonce of the embedded Cipher State but does not change its key) + # We check if only the handshake hash value and the Symmetric State changed accordingly + check: + cs != getCipherState(symmetricState) + getKey(cs) == getKey(getCipherState(symmetricState)) + getNonce(getCipherState(symmetricState)) == nonce + 1 + ck == getChainingKey(symmetricState) + h != getHandshakeHash(symmetricState) + + # We restore the symmetric State to its initial value to test decryption + symmetricState = initialSymmetricState + + # We execute decryptAndHash over the ciphertext + var decrypted = decryptAndHash(symmetricState, ciphertext) + + # decryptAndHash combines decryptWithAd and mixHash over the ciphertext (encryption increases the nonce of the embedded Cipher State but does not change its key) + # We check if only the handshake hash value and the Symmetric State changed accordingly + # We further check if decryption corresponds to the original plaintext + check: + cs != getCipherState(symmetricState) + getKey(cs) == getKey(getCipherState(symmetricState)) + getNonce(getCipherState(symmetricState)) == nonce + 1 + ck == getChainingKey(symmetricState) + h != getHandshakeHash(symmetricState) + decrypted == plaintext + + ######################################## + # split + ######################################## + + # If at least one mixKey is executed (as above), ck is non-empty + check: + getChainingKey(symmetricState) != EmptyKey + + # When a Symmetric State's ck is non-empty, we can execute split, which creates two distinct Cipher States cs1 and cs2 + # with non-empty encryption keys and nonce set to 0 + var (cs1, cs2) = split(symmetricState) + + check: + getKey(cs1) != EmptyKey + getKey(cs2) != EmptyKey + getNonce(cs1) == 0.uint64 + getNonce(cs2) == 0.uint64 + getKey(cs1) != getKey(cs2) diff --git a/waku/v2/protocol/waku_noise/noise.nim b/waku/v2/protocol/waku_noise/noise.nim index c9f91dd03..e41c746e7 100644 --- a/waku/v2/protocol/waku_noise/noise.nim +++ b/waku/v2/protocol/waku_noise/noise.nim @@ -7,16 +7,17 @@ {.push raises: [Defect].} -import std/[options, tables, strutils] +import std/[oids, options, strutils, tables] import chronos import chronicles import bearssl -import stew/[results, endians2] +import stew/[results, endians2, byteutils] import nimcrypto/[utils, sha2, hmac] import libp2p/utility import libp2p/errors -import libp2p/crypto/[crypto, chacha20poly1305, curve25519] +import libp2p/crypto/[crypto, chacha20poly1305, curve25519, hkdf] +import libp2p/protocols/secure/secure logScope: @@ -28,13 +29,19 @@ logScope: const # EmptyKey represents a non-initialized ChaChaPolyKey - EmptyKey = default(ChaChaPolyKey) + EmptyKey* = default(ChaChaPolyKey) # The maximum ChaChaPoly allowed nonce in Noise Handshakes - NonceMax = uint64.high - 1 + NonceMax* = uint64.high - 1 type + + ################################# + # Elliptic Curve arithemtic + ################################# + # Default underlying elliptic curve arithmetic (useful for switching to multiple ECs) # Current default is Curve25519 + EllipticCurve = Curve25519 EllipticCurveKey = Curve25519Key # An EllipticCurveKey (public, private) key pair @@ -42,6 +49,10 @@ type privateKey: EllipticCurveKey publicKey: EllipticCurveKey + ################################# + # Noise Public Keys + ################################# + # A Noise public key is a public key exchanged during Noise handshakes (no private part) # This follows https://rfc.vac.dev/spec/35/#public-keys-serialization # pk contains the X coordinate of the public key, if unencrypted (this implies flag = 0) @@ -51,6 +62,10 @@ type flag: uint8 pk: seq[byte] + ################################# + # ChaChaPoly Encryption + ################################# + # A ChaChaPoly ciphertext (data) + authorization tag (tag) ChaChaPolyCiphertext* = object data*: seq[byte] @@ -62,6 +77,96 @@ type nonce: ChaChaPolyNonce ad: seq[byte] + ################################# + # Noise handshake patterns + ################################# + + # The Noise tokens appearing in Noise (pre)message patterns + # as in http://www.noiseprotocol.org/noise.html#handshake-pattern-basics + NoiseTokens = enum + T_e = "e" + T_s = "s" + T_es = "es" + T_ee = "ee" + T_se = "se" + T_ss = "se" + T_psk = "psk" + + # The direction of a (pre)message pattern in canonical form (i.e. Alice-initiated form) + # as in http://www.noiseprotocol.org/noise.html#alice-and-bob + MessageDirection* = enum + D_r = "->" + D_l = "<-" + + # The pre message pattern consisting of a message direction and some Noise tokens, if any. + # (if non empty, only tokens e and s are allowed: http://www.noiseprotocol.org/noise.html#handshake-pattern-basics) + PreMessagePattern* = object + direction: MessageDirection + tokens: seq[NoiseTokens] + + # The message pattern consisting of a message direction and some Noise tokens + # All Noise tokens are allowed + MessagePattern* = object + direction: MessageDirection + tokens: seq[NoiseTokens] + + # The handshake pattern object. It stores the handshake protocol name, the handshake pre message patterns and the handshake message patterns + HandshakePattern* = object + name*: string + preMessagePatterns*: seq[PreMessagePattern] + messagePatterns*: seq[MessagePattern] + + ################################# + # Noise state machine + ################################# + + # The Cipher State as in https://noiseprotocol.org/noise.html#the-cipherstate-object + # Contains an encryption key k and a nonce n (used in Noise as a counter) + CipherState* = object + k: ChaChaPolyKey + n: uint64 + + # The Symmetric State as in https://noiseprotocol.org/noise.html#the-symmetricstate-object + # Contains a Cipher State cs, the chaining key ck and the handshake hash value h + SymmetricState* = object + cs: CipherState + ck: ChaChaPolyKey + h: MDigest[256] + + # The Handshake State as in https://noiseprotocol.org/noise.html#the-handshakestate-object + # Contains + # - the local and remote ephemeral/static keys e,s,re,rs (if any) + # - the initiator flag (true if the user creating the state is the handshake initiator, false otherwise) + # - the handshakePattern (containing the handshake protocol name, and (pre)message patterns) + # This object is futher extended from specifications by storing: + # - a message pattern index msgPatternIdx indicating the next handshake message pattern to process + # - the user's preshared psk, if any + HandshakeState = object + s: KeyPair + e: KeyPair + rs: EllipticCurveKey + re: EllipticCurveKey + ss: SymmetricState + initiator: bool + handshakePattern: HandshakePattern + msgPatternIdx: uint8 + psk: seq[byte] + + # When a handshake is complete, the HandhshakeResult will contain the two + # Cipher States used to encrypt/decrypt outbound/inbound messages + # The recipient static key rs and handshake hash values h are stored to address some possible future applications (channel-binding, session management, etc.). + # However, are not required by Noise specifications and are thus optional + HandshakeResult = object + csInbound: CipherState + csOutbound: CipherState + # Optional fields: + rs: EllipticCurveKey + h: MDigest[256] + + ################################# + # Waku Payload V2 + ################################# + # PayloadV2 defines an object for Waku payloads with version 2 as in # https://rfc.vac.dev/spec/35/#public-keys-serialization # It contains a protocol ID field, the handshake message (for Noise handshakes) and @@ -71,7 +176,10 @@ type handshakeMessage: seq[NoisePublicKey] transportMessage: seq[byte] + ################################# # Some useful error types + ################################# + NoiseError* = object of LPError NoiseHandshakeError* = object of NoiseError NoiseEmptyChaChaPolyInput* = object of NoiseError @@ -81,9 +189,66 @@ type NoiseMalformedHandshake* = object of NoiseError +################################# +# Constants (supported protocols) +################################# +const + + # The empty pre message patterns + EmptyPreMessagePattern: seq[PreMessagePattern] = @[] + + # Supported Noise handshake patterns as defined in https://rfc.vac.dev/spec/35/#specification + NoiseHandshakePatterns* = { + "K1K1": HandshakePattern(name: "Noise_K1K1_25519_ChaChaPoly_SHA256", + preMessagePatterns: @[PreMessagePattern(direction: D_r, tokens: @[T_s]), + PreMessagePattern(direction: D_l, tokens: @[T_s])], + messagePatterns: @[ MessagePattern(direction: D_r, tokens: @[T_e]), + MessagePattern(direction: D_l, tokens: @[T_e, T_ee, T_es]), + MessagePattern(direction: D_r, tokens: @[T_se])] + ), + + "XK1": HandshakePattern(name: "Noise_XK1_25519_ChaChaPoly_SHA256", + preMessagePatterns: @[PreMessagePattern(direction: D_l, tokens: @[T_s])], + messagePatterns: @[ MessagePattern(direction: D_r, tokens: @[T_e]), + MessagePattern(direction: D_l, tokens: @[T_e, T_ee, T_es]), + MessagePattern(direction: D_r, tokens: @[T_s, T_se])] + ), + + "XX": HandshakePattern(name: "Noise_XX_25519_ChaChaPoly_SHA256", + preMessagePatterns: EmptyPreMessagePattern, + messagePatterns: @[ MessagePattern(direction: D_r, tokens: @[T_e]), + MessagePattern(direction: D_l, tokens: @[T_e, T_ee, T_s, T_es]), + MessagePattern(direction: D_r, tokens: @[T_s, T_se])] + ), + + "XXpsk0": HandshakePattern(name: "Noise_XXpsk0_25519_ChaChaPoly_SHA256", + preMessagePatterns: EmptyPreMessagePattern, + messagePatterns: @[ MessagePattern(direction: D_r, tokens: @[T_psk, T_e]), + MessagePattern(direction: D_l, tokens: @[T_e, T_ee, T_s, T_es]), + MessagePattern(direction: D_r, tokens: @[T_s, T_se])] + ) + }.toTable() + + + # Supported Protocol ID for PayloadV2 objects + # Protocol IDs are defined according to https://rfc.vac.dev/spec/35/#specification + PayloadV2ProtocolIDs* = { + + "": 0.uint8, + "Noise_K1K1_25519_ChaChaPoly_SHA256": 10.uint8, + "Noise_XK1_25519_ChaChaPoly_SHA256": 11.uint8, + "Noise_XX_25519_ChaChaPoly_SHA256": 12.uint8, + "Noise_XXpsk0_25519_ChaChaPoly_SHA256": 13.uint8, + "ChaChaPoly": 30.uint8 + + }.toTable() + + ################################################################# +################################# # Utilities +################################# # Generates random byte sequences of given size proc randomSeqByte*(rng: var BrHmacDrbgContext, size: int): seq[byte] = @@ -98,10 +263,354 @@ proc genKeyPair*(rng: var BrHmacDrbgContext): KeyPair = keyPair.publicKey = keyPair.privateKey.public() return keyPair +# Gets private key from a key pair +proc getPrivateKey*(keypair: KeyPair): EllipticCurveKey = + return keypair.privateKey + +# Gets public key from a key pair +proc getPublicKey*(keypair: KeyPair): EllipticCurveKey = + return keypair.publicKey + +# Prints Handshake Patterns using Noise pattern layout +proc print*(self: HandshakePattern) + {.raises: [IOError, NoiseMalformedHandshake].}= + try: + if self.name != "": + stdout.write self.name, ":\n" + stdout.flushFile() + #We iterate over pre message patterns, if any + if self.preMessagePatterns != EmptyPreMessagePattern: + for pattern in self.preMessagePatterns: + stdout.write " ", pattern.direction + var first = true + for token in pattern.tokens: + if first: + stdout.write " ", token + first = false + else: + stdout.write ", ", token + stdout.write "\n" + stdout.flushFile() + stdout.write " ...\n" + stdout.flushFile() + #We iterate over message patterns + for pattern in self.messagePatterns: + stdout.write " ", pattern.direction + var first = true + for token in pattern.tokens: + if first: + stdout.write " ", token + first = false + else: + stdout.write ", ", token + stdout.write "\n" + stdout.flushFile() + except: + raise newException(NoiseMalformedHandshake, "HandshakePattern malformed") + +# Hashes a Noise protocol name using SHA256 +proc hashProtocol(protocolName: string): MDigest[256] = + + # The output hash value + var hash: MDigest[256] + + # From Noise specification: Section 5.2 + # http://www.noiseprotocol.org/noise.html#the-symmetricstate-object + # If protocol_name is less than or equal to HASHLEN bytes in length, + # sets h equal to protocol_name with zero bytes appended to make HASHLEN bytes. + # Otherwise sets h = HASH(protocol_name). + if protocolName.len <= 32: + hash.data[0..protocolName.high] = protocolName.toBytes + else: + hash = sha256.digest(protocolName) + + return hash + +# Performs a Diffie-Hellman operation between two elliptic curve keys (one private, one public) +proc dh*(private: EllipticCurveKey, public: EllipticCurveKey): EllipticCurveKey = + + # The output result of the Diffie-Hellman operation + var output: EllipticCurveKey + + # Since the EC multiplication writes the result to the input, we copy the input to the output variable + output = public + # We execute the DH operation + EllipticCurve.mul(output, private) + + return output + ################################################################# +# Noise state machine primitives + +# Overview : +# - Alice and Bob process (i.e. read and write, based on their role) each token appearing in a handshake pattern, consisting of pre-message and message patterns; +# - Both users initialize and update according to processed tokens a Handshake State, a Symmetric State and a Cipher State; +# - A preshared key psk is processed by calling MixKeyAndHash(psk); +# - When an ephemeral public key e is read or written, the handshake hash value h is updated by calling mixHash(e); If the handshake expects a psk, MixKey(e) is further called +# - When an encrypted static public key s or a payload message m is read, it is decrypted with decryptAndHash; +# - When a static public key s or a payload message is writted, it is encrypted with encryptAndHash; +# - When any Diffie-Hellman token ee, es, se, ss is read or written, the chaining key ck is updated by calling MixKey on the computed secret; +# - If all tokens are processed, users compute two new Cipher States by calling Split; +# - The two Cipher States obtained from Split are used to encrypt/decrypt outbound/inbound messages. + +################################# +# Cipher State Primitives +################################# + +# Checks if a Cipher State has an encryption key set +proc hasKey(cs: CipherState): bool = + return (cs.k != EmptyKey) + +# Encrypts a plaintext using key material in a Noise Cipher State +# The CipherState is updated increasing the nonce (used as a counter in Noise) by one +proc encryptWithAd*(state: var CipherState, ad, plaintext: openArray[byte]): seq[byte] + {.raises: [Defect, NoiseNonceMaxError].} = + + # We raise an error if encryption is called using a Cipher State with nonce greater than MaxNonce + if state.n > NonceMax: + raise newException(NoiseNonceMaxError, "Noise max nonce value reached") + + var ciphertext: seq[byte] + + # If an encryption key is set in the Cipher state, we proceed with encryption + if state.hasKey: + + # The output is the concatenation of the ciphertext and authorization tag + # We define its length accordingly + ciphertext = newSeqOfCap[byte](plaintext.len + sizeof(ChaChaPolyTag)) + + # Since ChaChaPoly encryption primitive overwrites the input with the output, + # we copy the plaintext in the output ciphertext variable and we pass it to encryption + ciphertext.add(plaintext) + + # The nonce is read from the input CipherState + # By Noise specification the nonce is 8 bytes long out of the 12 bytes supported by ChaChaPoly + var nonce: ChaChaPolyNonce + nonce[4..<12] = toBytesLE(state.n) + + # We perform encryption and we store the authorization tag + var authorizationTag: ChaChaPolyTag + ChaChaPoly.encrypt(state.k, nonce, authorizationTag, ciphertext, ad) + + # We append the authorization tag to ciphertext + ciphertext.add(authorizationTag) + + # We increase the Cipher state nonce + inc state.n + # If the nonce is greater than the maximum allowed nonce, we raise an exception + if state.n > NonceMax: + raise newException(NoiseNonceMaxError, "Noise max nonce value reached") + + trace "encryptWithAd", authorizationTag = byteutils.toHex(authorizationTag), ciphertext = ciphertext, nonce = state.n - 1 + + # Otherwise we return the input plaintext according to specification http://www.noiseprotocol.org/noise.html#the-cipherstate-object + else: + + ciphertext = @plaintext + debug "encryptWithAd called with no encryption key set. Returning plaintext." + + return ciphertext + +# Decrypts a ciphertext using key material in a Noise Cipher State +# The CipherState is updated increasing the nonce (used as a counter in Noise) by one +proc decryptWithAd*(state: var CipherState, ad, ciphertext: openArray[byte]): seq[byte] + {.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} = + + # We raise an error if encryption is called using a Cipher State with nonce greater than MaxNonce + if state.n > NonceMax: + raise newException(NoiseNonceMaxError, "Noise max nonce value reached") + + var plaintext: seq[byte] + + # If an encryption key is set in the Cipher state, we proceed with decryption + if state.hasKey: + + # We read the authorization appendend at the end of a ciphertext + let inputAuthorizationTag = ciphertext.toOpenArray(ciphertext.len - ChaChaPolyTag.len, ciphertext.high).intoChaChaPolyTag + + var + authorizationTag: ChaChaPolyTag + nonce: ChaChaPolyNonce + + # The nonce is read from the input CipherState + # By Noise specification the nonce is 8 bytes long out of the 12 bytes supported by ChaChaPoly + nonce[4..<12] = toBytesLE(state.n) + + # Since ChaChaPoly decryption primitive overwrites the input with the output, + # we copy the ciphertext (authorization tag excluded) in the output plaintext variable and we pass it to decryption + plaintext = ciphertext[0..(ciphertext.high - ChaChaPolyTag.len)] + + ChaChaPoly.decrypt(state.k, nonce, authorizationTag, plaintext, ad) + + # We check if the input authorization tag matches the decryption authorization tag + if inputAuthorizationTag != authorizationTag: + debug "decryptWithAd failed", plaintext = plaintext, ciphertext = ciphertext, inputAuthorizationTag = inputAuthorizationTag, authorizationTag = authorizationTag + raise newException(NoiseDecryptTagError, "decryptWithAd failed tag authentication.") + + # We increase the Cipher state nonce + inc state.n + # If the nonce is greater than the maximum allowed nonce, we raise an exception + if state.n > NonceMax: + raise newException(NoiseNonceMaxError, "Noise max nonce value reached") + + trace "decryptWithAd", inputAuthorizationTag = inputAuthorizationTag, authorizationTag = authorizationTag, nonce = state.n + + # Otherwise we return the input ciphertext according to specification http://www.noiseprotocol.org/noise.html#the-cipherstate-object + else: + + plaintext = @ciphertext + debug "decryptWithAd called with no encryption key set. Returning ciphertext." + + return plaintext + +# Sets the nonce of a Cipher State +proc setNonce*(cs: var CipherState, nonce: uint64) = + cs.n = nonce + +# Sets the key of a Cipher State +proc setCipherStateKey*(cs: var CipherState, key: ChaChaPolyKey) = + cs.k = key + +# Generates a random Symmetric Cipher State for test purposes +proc randomCipherState*(rng: var BrHmacDrbgContext, nonce: uint64 = 0): CipherState = + var randomCipherState: CipherState + brHmacDrbgGenerate(rng, randomCipherState.k) + setNonce(randomCipherState, nonce) + return randomCipherState + + +# Gets the key of a Cipher State +proc getKey*(cs: CipherState): ChaChaPolyKey = + return cs.k + +# Gets the nonce of a Cipher State +proc getNonce*(cs: CipherState): uint64 = + return cs.n + +################################# +# Symmetric State primitives +################################# + +# Initializes a Symmetric State +proc init*(_: type[SymmetricState], hsPattern: HandshakePattern): SymmetricState = + var ss: SymmetricState + # We compute the hash of the protocol name + ss.h = hsPattern.name.hashProtocol + # We initialize the chaining key ck + ss.ck = ss.h.data.intoChaChaPolyKey + # We initialize the Cipher state + ss.cs = CipherState(k: EmptyKey) + return ss + +# MixKey as per Noise specification http://www.noiseprotocol.org/noise.html#the-symmetricstate-object +# Updates a Symmetric state chaining key and symmetric state +proc mixKey*(ss: var SymmetricState, inputKeyMaterial: ChaChaPolyKey) = + # We derive two keys using HKDF + var tempKeys: array[2, ChaChaPolyKey] + sha256.hkdf(ss.ck, inputKeyMaterial, [], tempKeys) + # We update ck and the Cipher state's key k using the output of HDKF + ss.ck = tempKeys[0] + ss.cs = CipherState(k: tempKeys[1]) + trace "mixKey", ck = ss.ck, k = ss.cs.k + +# MixHash as per Noise specification http://www.noiseprotocol.org/noise.html#the-symmetricstate-object +# Hashes data into a Symmetric State's handshake hash value h +proc mixHash*(ss: var SymmetricState, data: openArray[byte]) = + # We prepare the hash context + var ctx: sha256 + ctx.init() + # We add the previous handshake hash + ctx.update(ss.h.data) + # We append the input data + ctx.update(data) + # We hash and store the result in the Symmetric State's handshake hash value + ss.h = ctx.finish() + trace "mixHash", hash = ss.h.data + +# mixKeyAndHash as per Noise specification http://www.noiseprotocol.org/noise.html#the-symmetricstate-object +# Combines MixKey and MixHash +proc mixKeyAndHash*(ss: var SymmetricState, inputKeyMaterial: openArray[byte]) {.used.} = + var tempKeys: array[3, ChaChaPolyKey] + # Derives 3 keys using HKDF, the chaining key and the input key material + sha256.hkdf(ss.ck, inputKeyMaterial, [], tempKeys) + # Sets the chaining key + ss.ck = tempKeys[0] + # Updates the handshake hash value + ss.mixHash(tempKeys[1]) + # Updates the Cipher state's key + # Note for later support of 512 bits hash functions: "If HASHLEN is 64, then truncates tempKeys[2] to 32 bytes." + ss.cs = CipherState(k: tempKeys[2]) + +# EncryptAndHash as per Noise specification http://www.noiseprotocol.org/noise.html#the-symmetricstate-object +# Combines encryptWithAd and mixHash +proc encryptAndHash*(ss: var SymmetricState, plaintext: openArray[byte]): seq[byte] + {.raises: [Defect, NoiseNonceMaxError].} = + # The output ciphertext + var ciphertext: seq[byte] + # Note that if an encryption key is not set yet in the Cipher state, ciphertext will be equal to plaintex + ciphertext = ss.cs.encryptWithAd(ss.h.data, plaintext) + # We call mixHash over the result + ss.mixHash(ciphertext) + return ciphertext + +# DecryptAndHash as per Noise specification http://www.noiseprotocol.org/noise.html#the-symmetricstate-object +# Combines decryptWithAd and mixHash +proc decryptAndHash*(ss: var SymmetricState, ciphertext: openArray[byte]): seq[byte] + {.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} = + # The output plaintext + var plaintext: seq[byte] + # Note that if an encryption key is not set yet in the Cipher state, plaintext will be equal to ciphertext + plaintext = ss.cs.decryptWithAd(ss.h.data, ciphertext) + # According to specification, the ciphertext enters mixHash (and not the plaintext) + ss.mixHash(ciphertext) + return plaintext + +# Split as per Noise specification http://www.noiseprotocol.org/noise.html#the-symmetricstate-object +# Once a handshake is complete, returns two Cipher States to encrypt/decrypt outbound/inbound messages +proc split*(ss: var SymmetricState): tuple[cs1, cs2: CipherState] = + # Derives 2 keys using HKDF and the chaining key + var tempKeys: array[2, ChaChaPolyKey] + sha256.hkdf(ss.ck, [], [], tempKeys) + # Returns a tuple of two Cipher States initialized with the derived keys + return (CipherState(k: tempKeys[0]), CipherState(k: tempKeys[1])) + +# Gets the chaining key field of a Symmetric State +proc getChainingKey*(ss: SymmetricState): ChaChaPolyKey = + return ss.ck + +# Gets the handshake hash field of a Symmetric State +proc getHandshakeHash*(ss: SymmetricState): MDigest[256] = + return ss.h + +# Gets the Cipher State field of a Symmetric State +proc getCipherState*(ss: SymmetricState): CipherState = + return ss.cs + +################################# +# Handshake State primitives +################################# + +# Initializes a Handshake State +proc init*(_: type[HandshakeState], hsPattern: HandshakePattern, psk: seq[byte] = @[]): HandshakeState = + # The output Handshake State + var hs: HandshakeState + # By default the Handshake State initiator flag is set to false + # Will be set to true when the user associated to the handshake state starts an handshake + hs.initiator = false + # We copy the information on the handshake pattern for which the state is initialized (protocol name, handshake pattern, psk) + hs.handshakePattern = hsPattern + hs.psk = psk + # We initialize the Symmetric State + hs.ss = SymmetricState.init(hsPattern) + return hs + +################################################################# + +################################# # ChaChaPoly Symmetric Cipher +################################# # ChaChaPoly encryption # It takes a Cipher State (with key, nonce, and associated data) and encrypts a plaintext @@ -145,17 +654,23 @@ proc decrypt*( # the ciphertext (overwritten to plaintext) (data), the associated data (ad) ChaChaPoly.decrypt(state.k, state.nonce, tagOut, plaintext, state.ad) #TODO: add unpadding - trace "decrypt", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.nonce + trace "decrypt", tagIn = tagIn, tagOut = tagOut, nonce = state.nonce # We check if the authorization tag computed while decrypting is the same as the input tag if tagIn != tagOut: debug "decrypt failed", plaintext = shortLog(plaintext) raise newException(NoiseDecryptTagError, "decrypt tag authentication failed.") return plaintext +# Generates a random ChaChaPolyKey for testing encryption/decryption +proc randomChaChaPolyKey*(rng: var BrHmacDrbgContext): ChaChaPolyKey = + var key: ChaChaPolyKey + brHmacDrbgGenerate(rng, key) + return key + # Generates a random ChaChaPoly Cipher State for testing encryption/decryption proc randomChaChaPolyCipherState*(rng: var BrHmacDrbgContext): ChaChaPolyCipherState = var randomCipherState: ChaChaPolyCipherState - brHmacDrbgGenerate(rng, randomCipherState.k) + randomCipherState.k = randomChaChaPolyKey(rng) brHmacDrbgGenerate(rng, randomCipherState.nonce) randomCipherState.ad = newSeq[byte](32) brHmacDrbgGenerate(rng, randomCipherState.ad) @@ -164,7 +679,9 @@ proc randomChaChaPolyCipherState*(rng: var BrHmacDrbgContext): ChaChaPolyCipherS ################################################################# +################################# # Noise Public keys +################################# # Checks equality between two Noise public keys proc `==`(k1, k2: NoisePublicKey): bool = @@ -255,10 +772,11 @@ proc decryptNoisePublicKey*(cs: ChaChaPolyCipherState, noisePublicKey: NoisePubl decryptedNoisePublicKey = noisePublicKey return decryptedNoisePublicKey - ################################################################# +################################# # Payload encoding/decoding procedures +################################# # Checks equality between two PayloadsV2 objects proc `==`(p1, p2: PayloadV2): bool = @@ -334,7 +852,6 @@ proc serializePayloadV2*(self: PayloadV2): Result[seq[byte], cstring] = return ok(payload) - # Deserializes a byte sequence to a PayloadV2 object according to https://rfc.vac.dev/spec/35/. # The input serialized payload concatenates the output PayloadV2 object fields as # payload = ( protocolId || serializedHandshakeMessageLen || serializedHandshakeMessage || transportMessageLen || transportMessage) @@ -376,7 +893,7 @@ proc deserializePayloadV2*(payload: seq[byte]): Result[PayloadV2, cstring] # If the key is unencrypted, we only read the X coordinate of the EC public key and we deserialize into a Noise Public Key if flag == 0: pkLen = 1 + EllipticCurveKey.len - handshake_message.add intoNoisePublicKey(payload[i..