mirror of https://github.com/waku-org/nwaku.git
feat(noise): add Noise Handshake State Machine and primitives
This commit is contained in:
parent
430f2ecd67
commit
5d63431f34
|
@ -1 +1 @@
|
||||||
Subproject commit 92cd608a5f47de1aa55861afa6dcc13bea4ae842
|
Subproject commit d442d84d221655ea25271b41bd2de546bafe4914
|
|
@ -7,17 +7,18 @@
|
||||||
|
|
||||||
{.push raises: [Defect].}
|
{.push raises: [Defect].}
|
||||||
|
|
||||||
import std/[options, tables]
|
import std/[oids, strformat, options, math, tables]
|
||||||
import chronos
|
import chronos
|
||||||
import chronicles
|
import chronicles
|
||||||
import bearssl
|
import bearssl
|
||||||
import strutils
|
import strutils
|
||||||
import stew/[results, endians2]
|
import stew/[results, endians2, byteutils]
|
||||||
import nimcrypto/[utils, sha2, hmac]
|
import nimcrypto/[utils, sha2, hmac]
|
||||||
|
|
||||||
import libp2p/utility
|
import libp2p/utility
|
||||||
import libp2p/errors
|
import libp2p/errors
|
||||||
import libp2p/crypto/[crypto, chacha20poly1305, curve25519]
|
import libp2p/crypto/[crypto, chacha20poly1305, curve25519, hkdf]
|
||||||
|
import libp2p/protocols/secure/secure
|
||||||
|
|
||||||
|
|
||||||
logScope:
|
logScope:
|
||||||
|
@ -72,6 +73,63 @@ type
|
||||||
handshakeMessage: seq[NoisePublicKey]
|
handshakeMessage: seq[NoisePublicKey]
|
||||||
transportMessage: seq[byte]
|
transportMessage: seq[byte]
|
||||||
|
|
||||||
|
#Noise Handshakes
|
||||||
|
|
||||||
|
NoiseTokens* = enum
|
||||||
|
T_e = "e"
|
||||||
|
T_s = "s"
|
||||||
|
T_es = "es"
|
||||||
|
T_ee = "ee"
|
||||||
|
T_se = "se"
|
||||||
|
T_ss = "se"
|
||||||
|
T_psk = "psk"
|
||||||
|
T_none = ""
|
||||||
|
|
||||||
|
MessageDirection* = enum
|
||||||
|
D_r = "->"
|
||||||
|
D_l = "<-"
|
||||||
|
D_none = ""
|
||||||
|
|
||||||
|
HandshakePattern* = object
|
||||||
|
name*: string
|
||||||
|
pre_message_patterns*: seq[(MessageDirection, seq[NoiseTokens])]
|
||||||
|
message_patterns*: seq[(MessageDirection, seq[NoiseTokens])]
|
||||||
|
|
||||||
|
#Noise states
|
||||||
|
|
||||||
|
# https://noiseprotocol.org/noise.html#the-cipherstate-object
|
||||||
|
CipherState* = object
|
||||||
|
k: ChaChaPolyKey
|
||||||
|
n: uint64
|
||||||
|
|
||||||
|
# https://noiseprotocol.org/noise.html#the-symmetricstate-object
|
||||||
|
SymmetricState* = object
|
||||||
|
cs: CipherState
|
||||||
|
ck: ChaChaPolyKey
|
||||||
|
h: MDigest[256]
|
||||||
|
|
||||||
|
# https://noiseprotocol.org/noise.html#the-handshakestate-object
|
||||||
|
HandshakeState = object
|
||||||
|
s: KeyPair
|
||||||
|
e: KeyPair
|
||||||
|
rs: Curve25519Key
|
||||||
|
re: Curve25519Key
|
||||||
|
ss: SymmetricState
|
||||||
|
initiator: bool
|
||||||
|
handshake_pattern: HandshakePattern
|
||||||
|
msg_pattern_idx: uint8
|
||||||
|
psk: seq[byte]
|
||||||
|
|
||||||
|
HandshakeResult = object
|
||||||
|
cs1: CipherState
|
||||||
|
cs2: CipherState
|
||||||
|
rs: Curve25519Key
|
||||||
|
h: MDigest[256] #The handshake state for channel binding
|
||||||
|
|
||||||
|
NoiseState* = object
|
||||||
|
hs: HandshakeState
|
||||||
|
hr: HandshakeResult
|
||||||
|
|
||||||
# Some useful error types
|
# Some useful error types
|
||||||
NoiseError* = object of LPError
|
NoiseError* = object of LPError
|
||||||
NoiseHandshakeError* = object of NoiseError
|
NoiseHandshakeError* = object of NoiseError
|
||||||
|
@ -81,6 +139,54 @@ type
|
||||||
NoisePublicKeyError* = object of NoiseError
|
NoisePublicKeyError* = object of NoiseError
|
||||||
NoiseMalformedHandshake* = object of NoiseError
|
NoiseMalformedHandshake* = object of NoiseError
|
||||||
|
|
||||||
|
# Supported Noise Handshake Patterns
|
||||||
|
const
|
||||||
|
EmptyMessagePattern = @[(D_none, @[T_none])]
|
||||||
|
|
||||||
|
NoiseHandshakePatterns* = {
|
||||||
|
|
||||||
|
"K1K1": HandshakePattern(name: "Noise_K1K1_25519_ChaChaPoly_SHA256",
|
||||||
|
pre_message_patterns: @[(D_r, @[T_s]),
|
||||||
|
(D_l, @[T_s])],
|
||||||
|
message_patterns: @[(D_r, @[T_e]),
|
||||||
|
(D_l, @[T_e, T_ee, T_es]),
|
||||||
|
(D_r, @[T_se])]
|
||||||
|
),
|
||||||
|
|
||||||
|
"XK1": HandshakePattern(name: "Noise_XK1_25519_ChaChaPoly_SHA256",
|
||||||
|
pre_message_patterns: @[(D_l, @[T_s])],
|
||||||
|
message_patterns: @[(D_r, @[T_e]),
|
||||||
|
(D_l, @[T_e, T_ee, T_es]),
|
||||||
|
(D_r, @[T_s, T_se])]
|
||||||
|
),
|
||||||
|
|
||||||
|
"XX": HandshakePattern(name: "Noise_XX_25519_ChaChaPoly_SHA256",
|
||||||
|
pre_message_patterns: EmptyMessagePattern,
|
||||||
|
message_patterns: @[(D_r, @[T_e]),
|
||||||
|
(D_l, @[T_e, T_ee, T_s, T_es]),
|
||||||
|
(D_r, @[T_s, T_se])]
|
||||||
|
),
|
||||||
|
|
||||||
|
"XXpsk0": HandshakePattern(name: "Noise_XXpsk0_25519_ChaChaPoly_SHA256",
|
||||||
|
pre_message_patterns: EmptyMessagePattern,
|
||||||
|
message_patterns: @[(D_r, @[T_psk, T_e]),
|
||||||
|
(D_l, @[T_e, T_ee, T_s, T_es]),
|
||||||
|
(D_r, @[T_s, T_se])]
|
||||||
|
)
|
||||||
|
|
||||||
|
}.toTable()
|
||||||
|
|
||||||
|
|
||||||
|
PayloadV2ProtocolIDs* = {
|
||||||
|
|
||||||
|
"": 0.uint8,
|
||||||
|
"Noise_K1K1_25519_ChaChaPoly_SHA256": 10.uint8,
|
||||||
|
"Noise_XK1_25519_ChaChaPoly_SHA256": 11.uint8,
|
||||||
|
"Noise_XX_25519_ChaChaPoly_SHA256": 12.uint8,
|
||||||
|
"Noise_XXpsk0_25519_ChaChaPoly_SHA256": 13.uint8,
|
||||||
|
"ChaChaPoly": 30.uint8
|
||||||
|
|
||||||
|
}.toTable()
|
||||||
|
|
||||||
#################################################################
|
#################################################################
|
||||||
|
|
||||||
|
@ -99,6 +205,172 @@ proc genKeyPair*(rng: var BrHmacDrbgContext): KeyPair =
|
||||||
keyPair.publicKey = keyPair.privateKey.public()
|
keyPair.publicKey = keyPair.privateKey.public()
|
||||||
return keyPair
|
return keyPair
|
||||||
|
|
||||||
|
#Printing Handshake Patterns
|
||||||
|
proc print*(self: HandshakePattern)
|
||||||
|
{.raises: [IOError].}=
|
||||||
|
try:
|
||||||
|
if self.name != "":
|
||||||
|
echo self.name, ":"
|
||||||
|
#We iterate over pre message patterns, if any
|
||||||
|
if self.pre_message_patterns != EmptyMessagePattern:
|
||||||
|
for pattern in self.pre_message_patterns:
|
||||||
|
stdout.write " ", pattern[0]
|
||||||
|
var first = true
|
||||||
|
for token in pattern[1]:
|
||||||
|
if first:
|
||||||
|
stdout.write " ", token
|
||||||
|
first = false
|
||||||
|
else:
|
||||||
|
stdout.write ", ", token
|
||||||
|
stdout.write "\n"
|
||||||
|
stdout.flushFile()
|
||||||
|
stdout.write " ...\n"
|
||||||
|
stdout.flushFile()
|
||||||
|
#We iterate over message patterns
|
||||||
|
for pattern in self.message_patterns:
|
||||||
|
stdout.write " ", pattern[0]
|
||||||
|
var first = true
|
||||||
|
for token in pattern[1]:
|
||||||
|
if first:
|
||||||
|
stdout.write " ", token
|
||||||
|
first = false
|
||||||
|
else:
|
||||||
|
stdout.write ", ", token
|
||||||
|
stdout.write "\n"
|
||||||
|
stdout.flushFile()
|
||||||
|
except:
|
||||||
|
echo "HandshakePattern malformed"
|
||||||
|
|
||||||
|
|
||||||
|
proc hashProtocol(name: string): MDigest[256] =
|
||||||
|
# If protocol_name is less than or equal to HASHLEN bytes in length,
|
||||||
|
# sets h equal to protocol_name with zero bytes appended to make HASHLEN bytes.
|
||||||
|
# Otherwise sets h = HASH(protocol_name).
|
||||||
|
|
||||||
|
if name.len <= 32:
|
||||||
|
result.data[0..name.high] = name.toBytes
|
||||||
|
else:
|
||||||
|
result = sha256.digest(name)
|
||||||
|
|
||||||
|
proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
|
||||||
|
result = pub
|
||||||
|
Curve25519.mul(result, priv)
|
||||||
|
|
||||||
|
# Cipherstate
|
||||||
|
|
||||||
|
proc hasKey(cs: CipherState): bool =
|
||||||
|
cs.k != EmptyKey
|
||||||
|
|
||||||
|
proc encrypt(
|
||||||
|
state: var CipherState,
|
||||||
|
data: var openArray[byte],
|
||||||
|
ad: openArray[byte]): ChaChaPolyTag
|
||||||
|
{.noinit, raises: [Defect, NoiseNonceMaxError].} =
|
||||||
|
|
||||||
|
var nonce: ChaChaPolyNonce
|
||||||
|
nonce[4..<12] = toBytesLE(state.n)
|
||||||
|
|
||||||
|
ChaChaPoly.encrypt(state.k, nonce, result, data, ad)
|
||||||
|
|
||||||
|
inc state.n
|
||||||
|
if state.n > NonceMax:
|
||||||
|
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
|
||||||
|
|
||||||
|
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte]
|
||||||
|
{.raises: [Defect, NoiseNonceMaxError].} =
|
||||||
|
result = newSeqOfCap[byte](data.len + sizeof(ChaChaPolyTag))
|
||||||
|
result.add(data)
|
||||||
|
|
||||||
|
let tag = encrypt(state, result, ad)
|
||||||
|
|
||||||
|
result.add(tag)
|
||||||
|
|
||||||
|
trace "encryptWithAd",
|
||||||
|
tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
||||||
|
|
||||||
|
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte]
|
||||||
|
{.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
|
||||||
|
var
|
||||||
|
tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag
|
||||||
|
tagOut: ChaChaPolyTag
|
||||||
|
nonce: ChaChaPolyNonce
|
||||||
|
nonce[4..<12] = toBytesLE(state.n)
|
||||||
|
result = data[0..(data.high - ChaChaPolyTag.len)]
|
||||||
|
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
|
||||||
|
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
|
||||||
|
if tagIn != tagOut:
|
||||||
|
debug "decryptWithAd failed", data = shortLog(data)
|
||||||
|
raise newException(NoiseDecryptTagError, "decryptWithAd failed tag authentication.")
|
||||||
|
inc state.n
|
||||||
|
if state.n > NonceMax:
|
||||||
|
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
|
||||||
|
|
||||||
|
# Symmetricstate
|
||||||
|
|
||||||
|
proc init*(_: type[SymmetricState], hs_pattern: HandshakePattern): SymmetricState =
|
||||||
|
result.h = hs_pattern.name.hashProtocol
|
||||||
|
result.ck = result.h.data.intoChaChaPolyKey
|
||||||
|
result.cs = CipherState(k: EmptyKey)
|
||||||
|
|
||||||
|
proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
|
||||||
|
var
|
||||||
|
temp_keys: array[2, ChaChaPolyKey]
|
||||||
|
sha256.hkdf(ss.ck, ikm, [], temp_keys)
|
||||||
|
ss.ck = temp_keys[0]
|
||||||
|
ss.cs = CipherState(k: temp_keys[1])
|
||||||
|
trace "mixKey", key = ss.cs.k.shortLog
|
||||||
|
|
||||||
|
proc mixHash(ss: var SymmetricState, data: openArray[byte]) =
|
||||||
|
var ctx: sha256
|
||||||
|
ctx.init()
|
||||||
|
ctx.update(ss.h.data)
|
||||||
|
ctx.update(data)
|
||||||
|
ss.h = ctx.finish()
|
||||||
|
trace "mixHash", hash = ss.h.data.shortLog
|
||||||
|
|
||||||
|
# We might use this for other handshake patterns/tokens
|
||||||
|
proc mixKeyAndHash(ss: var SymmetricState, ikm: openArray[byte]) {.used.} =
|
||||||
|
var
|
||||||
|
temp_keys: array[3, ChaChaPolyKey]
|
||||||
|
sha256.hkdf(ss.ck, ikm, [], temp_keys)
|
||||||
|
ss.ck = temp_keys[0]
|
||||||
|
ss.mixHash(temp_keys[1])
|
||||||
|
ss.cs = CipherState(k: temp_keys[2])
|
||||||
|
|
||||||
|
proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte]
|
||||||
|
{.raises: [Defect, NoiseNonceMaxError].} =
|
||||||
|
# according to spec if key is empty leave plaintext
|
||||||
|
if ss.cs.hasKey:
|
||||||
|
result = ss.cs.encryptWithAd(ss.h.data, data)
|
||||||
|
else:
|
||||||
|
result = @data
|
||||||
|
ss.mixHash(result)
|
||||||
|
|
||||||
|
proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte]
|
||||||
|
{.raises: [Defect, NoiseDecryptTagError, NoiseNonceMaxError].} =
|
||||||
|
# according to spec if key is empty leave plaintext
|
||||||
|
if ss.cs.hasKey and data.len > ChaChaPolyTag.len:
|
||||||
|
result = ss.cs.decryptWithAd(ss.h.data, data)
|
||||||
|
else:
|
||||||
|
result = @data
|
||||||
|
ss.mixHash(data)
|
||||||
|
|
||||||
|
proc split(ss: var SymmetricState): tuple[cs1, cs2: CipherState] =
|
||||||
|
var
|
||||||
|
temp_keys: array[2, ChaChaPolyKey]
|
||||||
|
sha256.hkdf(ss.ck, [], [], temp_keys)
|
||||||
|
return (CipherState(k: temp_keys[0]), CipherState(k: temp_keys[1]))
|
||||||
|
|
||||||
|
|
||||||
|
# Handshake state
|
||||||
|
|
||||||
|
proc init*(_: type[HandshakeState], hs_pattern: HandshakePattern, psk: seq[byte] = @[]): HandshakeState =
|
||||||
|
# set to true only if startHandshake is called over the handshake state
|
||||||
|
result.initiator = false
|
||||||
|
result.handshake_pattern = hs_pattern
|
||||||
|
result.psk = psk
|
||||||
|
result.ss = SymmetricState.init(hs_pattern)
|
||||||
|
|
||||||
|
|
||||||
#################################################################
|
#################################################################
|
||||||
|
|
||||||
|
@ -256,7 +528,6 @@ proc decryptNoisePublicKey*(cs: ChaChaPolyCipherState, noisePublicKey: NoisePubl
|
||||||
decryptedNoisePublicKey = noisePublicKey
|
decryptedNoisePublicKey = noisePublicKey
|
||||||
return decryptedNoisePublicKey
|
return decryptedNoisePublicKey
|
||||||
|
|
||||||
|
|
||||||
#################################################################
|
#################################################################
|
||||||
|
|
||||||
# Payload encoding/decoding procedures
|
# Payload encoding/decoding procedures
|
||||||
|
|
Loading…
Reference in New Issue