feat(noise): add Noise Handshake State Machine and primitives

This commit is contained in:
s1fr0 2022-03-31 00:14:13 +02:00
parent 07cf184336
commit 77e20c33be
No known key found for this signature in database
GPG Key ID: 2C041D60117BFF46
2 changed files with 288 additions and 14 deletions

View File

@ -7,7 +7,7 @@ import
../protocol/waku_message,
../protocol/waku_noise/noise
import libp2p/crypto/[chacha20poly1305, curve25519]
import libp2p/crypto/[curve25519]
export whisper_types, keys, options

View File

@ -9,21 +9,22 @@
{.push raises: [Defect].}
import std/[oids, options, tables]
import std/[oids, strformat, options, math, tables]
import chronos
import chronicles
import bearssl
import strutils
import stew/[endians2]
import stew/[endians2, byteutils]
import nimcrypto/[utils, sha2, hmac]
import libp2p/stream/[connection]
import libp2p/stream/[connection, streamseq]
import libp2p/peerid
import libp2p/peerinfo
import libp2p/protobuf/minprotobuf
import libp2p/utility
import libp2p/errors
import libp2p/crypto/[crypto, chacha20poly1305, curve25519]
import libp2p/crypto/[crypto, chacha20poly1305, curve25519, hkdf]
import libp2p/protocols/secure/secure
when defined(libp2p_dump):
@ -58,6 +59,63 @@ type
nonce*: ChaChaPolyNonce
ad*: seq[byte]
#Noise Handshakes
NoiseTokens* = enum
T_e = "e"
T_s = "s"
T_es = "es"
T_ee = "ee"
T_se = "se"
T_ss = "se"
T_psk = "psk"
T_none = ""
MessageDirection* = enum
D_r = "->"
D_l = "<-"
D_none = ""
HandshakePattern* = object
name*: string
pre_message_patterns*: seq[(MessageDirection, seq[NoiseTokens])]
message_patterns*: seq[(MessageDirection, seq[NoiseTokens])]
#Noise states
# https://noiseprotocol.org/noise.html#the-cipherstate-object
CipherState* = object
k: ChaChaPolyKey
n: uint64
# https://noiseprotocol.org/noise.html#the-symmetricstate-object
SymmetricState* = object
cs: CipherState
ck: ChaChaPolyKey
h: MDigest[256]
# https://noiseprotocol.org/noise.html#the-handshakestate-object
HandshakeState = object
s: KeyPair
e: KeyPair
rs: Curve25519Key
re: Curve25519Key
ss: SymmetricState
initiator: bool
handshake_pattern: HandshakePattern
msg_pattern_idx: uint8
psk: seq[byte]
HandshakeResult = object
cs1: CipherState
cs2: CipherState
rs: Curve25519Key
h: MDigest[256] #The handshake state for channel binding
NoiseState* = object
hs: HandshakeState
hr: HandshakeResult
NoiseError* = object of LPError
NoiseHandshakeError* = object of NoiseError
NoiseDecryptTagError* = object of NoiseError
@ -66,6 +124,228 @@ type
NoiseMalformedHandshake* = object of NoiseError
# Supported Noise Handshake Patterns
const
EmptyMessagePattern = @[(D_none, @[T_none])]
NoiseHandshakePatterns* = {
"K1K1": HandshakePattern(name: "Noise_K1K1_25519_ChaChaPoly_SHA256",
pre_message_patterns: @[(D_r, @[T_s]),
(D_l, @[T_s])],
message_patterns: @[(D_r, @[T_e]),
(D_l, @[T_e, T_ee, T_es]),
(D_r, @[T_se])]
),
"XK1": HandshakePattern(name: "Noise_XK1_25519_ChaChaPoly_SHA256",
pre_message_patterns: @[(D_l, @[T_s])],
message_patterns: @[(D_r, @[T_e]),
(D_l, @[T_e, T_ee, T_es]),
(D_r, @[T_s, T_se])]
),
"XX": HandshakePattern(name: "Noise_XX_25519_ChaChaPoly_SHA256",
pre_message_patterns: EmptyMessagePattern,
message_patterns: @[(D_r, @[T_e]),
(D_l, @[T_e, T_ee, T_s, T_es]),
(D_r, @[T_s, T_se])]
),
"XXpsk0": HandshakePattern(name: "Noise_XXpsk0_25519_ChaChaPoly_SHA256",
pre_message_patterns: EmptyMessagePattern,
message_patterns: @[(D_r, @[T_psk, T_e]),
(D_l, @[T_e, T_ee, T_s, T_es]),
(D_r, @[T_s, T_se])]
)
}.toTable()
PayloadV2ProtocolIDs* = {
"": 0.uint8,
"Noise_K1K1_25519_ChaChaPoly_SHA256": 10.uint8,
"Noise_XK1_25519_ChaChaPoly_SHA256": 11.uint8,
"Noise_XX_25519_ChaChaPoly_SHA256": 12.uint8,
"Noise_XXpsk0_25519_ChaChaPoly_SHA256": 13.uint8,
"ChaChaPoly": 30.uint8
}.toTable()
# Utility
#Printing Handshake Patterns
proc print*(self: HandshakePattern)
{.raises: [IOError].}=
try:
if self.name != "":
echo self.name, ":"
#We iterate over pre message patterns, if any
if self.pre_message_patterns != EmptyMessagePattern:
for pattern in self.pre_message_patterns:
stdout.write " ", pattern[0]
var first = true
for token in pattern[1]:
if first:
stdout.write " ", token
first = false
else:
stdout.write ", ", token
stdout.write "\n"
stdout.flushFile()
stdout.write " ...\n"
stdout.flushFile()
#We iterate over message patterns
for pattern in self.message_patterns:
stdout.write " ", pattern[0]
var first = true
for token in pattern[1]:
if first:
stdout.write " ", token
first = false
else:
stdout.write ", ", token
stdout.write "\n"
stdout.flushFile()
except:
echo "HandshakePattern malformed"
proc genKeyPair*(rng: var BrHmacDrbgContext): KeyPair =
result.privateKey = Curve25519Key.random(rng)
result.publicKey = result.privateKey.public()
proc hashProtocol(name: string): MDigest[256] =
# If protocol_name is less than or equal to HASHLEN bytes in length,
# sets h equal to protocol_name with zero bytes appended to make HASHLEN bytes.
# Otherwise sets h = HASH(protocol_name).
if name.len <= 32:
result.data[0..name.high] = name.toBytes
else:
result = sha256.digest(name)
proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
result = pub
Curve25519.mul(result, priv)
# Cipherstate
proc hasKey(cs: CipherState): bool =
cs.k != EmptyKey
proc encrypt(
state: var CipherState,
data: var openArray[byte],
ad: openArray[byte]): ChaChaPolyTag
{.noinit, raises: [Defect, NoiseNonceMaxError].} =
var nonce: ChaChaPolyNonce
nonce[4..<12] = toBytesLE(state.n)
ChaChaPoly.encrypt(state.k, nonce, result, data, ad)
inc state.n
if state.n > NonceMax:
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte]
{.raises: [Defect, NoiseNonceMaxError].} =
result = newSeqOfCap[byte](data.len + sizeof(ChaChaPolyTag))
result.add(data)
let tag = encrypt(state, result, ad)
result.add(tag)
trace "encryptWithAd",
tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte]
{.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
var
tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag
tagOut: ChaChaPolyTag
nonce: ChaChaPolyNonce
nonce[4..<12] = toBytesLE(state.n)
result = data[0..(data.high - ChaChaPolyTag.len)]
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
if tagIn != tagOut:
debug "decryptWithAd failed", data = shortLog(data)
raise newException(NoiseDecryptTagError, "decryptWithAd failed tag authentication.")
inc state.n
if state.n > NonceMax:
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
# Symmetricstate
proc init*(_: type[SymmetricState], hs_pattern: HandshakePattern): SymmetricState =
result.h = hs_pattern.name.hashProtocol
result.ck = result.h.data.intoChaChaPolyKey
result.cs = CipherState(k: EmptyKey)
proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
var
temp_keys: array[2, ChaChaPolyKey]
sha256.hkdf(ss.ck, ikm, [], temp_keys)
ss.ck = temp_keys[0]
ss.cs = CipherState(k: temp_keys[1])
trace "mixKey", key = ss.cs.k.shortLog
proc mixHash(ss: var SymmetricState, data: openArray[byte]) =
var ctx: sha256
ctx.init()
ctx.update(ss.h.data)
ctx.update(data)
ss.h = ctx.finish()
trace "mixHash", hash = ss.h.data.shortLog
# We might use this for other handshake patterns/tokens
proc mixKeyAndHash(ss: var SymmetricState, ikm: openArray[byte]) {.used.} =
var
temp_keys: array[3, ChaChaPolyKey]
sha256.hkdf(ss.ck, ikm, [], temp_keys)
ss.ck = temp_keys[0]
ss.mixHash(temp_keys[1])
ss.cs = CipherState(k: temp_keys[2])
proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte]
{.raises: [Defect, NoiseNonceMaxError].} =
# according to spec if key is empty leave plaintext
if ss.cs.hasKey:
result = ss.cs.encryptWithAd(ss.h.data, data)
else:
result = @data
ss.mixHash(result)
proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte]
{.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
# according to spec if key is empty leave plaintext
if ss.cs.hasKey and data.len > ChaChaPolyTag.len:
result = ss.cs.decryptWithAd(ss.h.data, data)
else:
result = @data
ss.mixHash(data)
proc split(ss: var SymmetricState): tuple[cs1, cs2: CipherState] =
var
temp_keys: array[2, ChaChaPolyKey]
sha256.hkdf(ss.ck, [], [], temp_keys)
return (CipherState(k: temp_keys[0]), CipherState(k: temp_keys[1]))
# Handshake state
proc init*(_: type[HandshakeState], hs_pattern: HandshakePattern, psk: seq[byte] = @[]): HandshakeState =
# set to true only if startHandshake is called over the handshake state
result.initiator = false
result.handshake_pattern = hs_pattern
result.psk = psk
result.ss = SymmetricState.init(hs_pattern)
#################################################################
@ -105,13 +385,6 @@ proc randomChaChaPolyCipherState*(rng: var BrHmacDrbgContext): ChaChaPolyCipherS
#################################################################
# Utility
proc genKeyPair*(rng: var BrHmacDrbgContext): KeyPair =
result.privateKey = Curve25519Key.random(rng)
result.publicKey = result.privateKey.public()
# Public keys serializations/encryption
proc `==`(k1, k2: NoisePublicKey): bool =
@ -170,9 +443,11 @@ proc decryptNoisePublicKey*(cs: ChaChaPolyCipherState, noisePublicKey: NoisePubl
#################################################################
# Payload functions
# Payload V2 functions
type
PayloadV2* = object
protocol_id: uint8
@ -215,7 +490,6 @@ proc encodeV2*(self: PayloadV2): Option[seq[byte]] =
return none(seq[byte])
let transport_message_len = self.transport_message.len
#let transport_message_len_len = ceil(log(transport_message_len, 8)).int
var payload = newSeqOfCap[byte](1 + #self.protocol_id.len +
1 + #ser_handshake_message_len