noise updates (#255)

* clear secrets explicitly
* simplify keygen
* avoid some trivial memory allocations
* fix little endian encoding of nonce
This commit is contained in:
Jacek Sieka 2020-07-09 10:53:19 +02:00 committed by GitHub
parent 4e12d0d97a
commit 45c089ff0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 146 deletions

View File

@ -105,10 +105,13 @@ proc mulgen*(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key)
proc public*(private: Curve25519Key): Curve25519Key = proc public*(private: Curve25519Key): Curve25519Key =
Curve25519.mulgen(result, private) Curve25519.mulgen(result, private)
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Result[Curve25519Key, Curve25519Error] = proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Curve25519Key =
var res: Curve25519Key var res: Curve25519Key
let defaultBrEc = brEcGetDefault() let defaultBrEc = brEcGetDefault()
if brEcKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) != Curve25519KeySize: let len = brEcKeygen(
err(Curver25519GenError) addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519)
else: # Per bearssl documentation, the keygen only fails if the curve is
ok(res) # unrecognised -
doAssert len == Curve25519KeySize, "Could not generate curve"
res

View File

@ -39,11 +39,11 @@ const
type type
EcPrivateKey* = ref object EcPrivateKey* = ref object
buffer*: seq[byte] buffer*: array[BR_EC_KBUF_PRIV_MAX_SIZE, byte]
key*: BrEcPrivateKey key*: BrEcPrivateKey
EcPublicKey* = ref object EcPublicKey* = ref object
buffer*: seq[byte] buffer*: array[BR_EC_KBUF_PUB_MAX_SIZE, byte]
key*: BrEcPublicKey key*: BrEcPublicKey
EcKeyPair* = object EcKeyPair* = object
@ -237,7 +237,6 @@ proc random*(
## secp521r1). ## secp521r1).
var ecimp = brEcGetDefault() var ecimp = brEcGetDefault()
var res = new EcPrivateKey var res = new EcPrivateKey
res.buffer = newSeq[byte](BR_EC_KBUF_PRIV_MAX_SIZE)
if brEcKeygen(addr rng.vtable, ecimp, if brEcKeygen(addr rng.vtable, ecimp,
addr res.key, addr res.buffer[0], addr res.key, addr res.buffer[0],
cast[cint](kind)) == 0: cast[cint](kind)) == 0:
@ -254,7 +253,6 @@ proc getKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] =
if seckey.key.curve in EcSupportedCurvesCint: if seckey.key.curve in EcSupportedCurvesCint:
var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve)) var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve))
var res = new EcPublicKey var res = new EcPublicKey
res.buffer = newSeq[byte](length)
if brEcComputePublicKey(ecimp, addr res.key, if brEcComputePublicKey(ecimp, addr res.key,
addr res.buffer[0], unsafeAddr seckey.key) == 0: addr res.buffer[0], unsafeAddr seckey.key) == 0:
err(EcKeyIncorrectError) err(EcKeyIncorrectError)
@ -621,7 +619,6 @@ proc init*(key: var EcPrivateKey, data: openarray[byte]): Result[void, Asn1Error
if checkScalar(raw.toOpenArray(), curve) == 1'u32: if checkScalar(raw.toOpenArray(), curve) == 1'u32:
key = new EcPrivateKey key = new EcPrivateKey
key.buffer = newSeq[byte](raw.length)
copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length) copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length)
key.key.x = cast[ptr cuchar](addr key.buffer[0]) key.key.x = cast[ptr cuchar](addr key.buffer[0])
key.key.xlen = raw.length key.key.xlen = raw.length
@ -681,7 +678,6 @@ proc init*(pubkey: var EcPublicKey, data: openarray[byte]): Result[void, Asn1Err
if checkPublic(raw.toOpenArray(), curve) != 0: if checkPublic(raw.toOpenArray(), curve) != 0:
pubkey = new EcPublicKey pubkey = new EcPublicKey
pubkey.buffer = newSeq[byte](raw.length)
copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length) copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length)
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0]) pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
pubkey.key.qlen = raw.length pubkey.key.qlen = raw.length
@ -769,7 +765,6 @@ proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool =
if checkScalar(data, curve) == 1'u32: if checkScalar(data, curve) == 1'u32:
let length = len(data) let length = len(data)
key = new EcPrivateKey key = new EcPrivateKey
key.buffer = newSeq[byte](length)
copyMem(addr key.buffer[0], unsafeAddr data[0], length) copyMem(addr key.buffer[0], unsafeAddr data[0], length)
key.key.x = cast[ptr cuchar](addr key.buffer[0]) key.key.x = cast[ptr cuchar](addr key.buffer[0])
key.key.xlen = length key.key.xlen = length
@ -801,7 +796,6 @@ proc initRaw*(pubkey: var EcPublicKey, data: openarray[byte]): bool =
if checkPublic(data, curve) != 0: if checkPublic(data, curve) != 0:
let length = len(data) let length = len(data)
pubkey = new EcPublicKey pubkey = new EcPublicKey
pubkey.buffer = newSeq[byte](length)
copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length) copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length)
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0]) pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
pubkey.key.qlen = length pubkey.key.qlen = length

View File

@ -12,15 +12,13 @@ import chronicles
import bearssl import bearssl
import stew/[endians2, byteutils] import stew/[endians2, byteutils]
import nimcrypto/[utils, sha2, hmac] import nimcrypto/[utils, sha2, hmac]
import ../../stream/lpstream import ../../stream/[connection, streamseq]
import ../../peerid import ../../peerid
import ../../peerinfo import ../../peerinfo
import ../../protobuf/minprotobuf import ../../protobuf/minprotobuf
import ../../utility import ../../utility
import ../../stream/lpstream
import secure, import secure,
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf], ../../crypto/[crypto, chacha20poly1305, curve25519, hkdf]
../../stream/bufferstream
logScope: logScope:
topics = "noise" topics = "noise"
@ -34,7 +32,7 @@ const
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256" ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
# Empty is a special value which indicates k has not yet been initialized. # Empty is a special value which indicates k has not yet been initialized.
EmptyKey: ChaChaPolyKey = [0.byte, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] EmptyKey = default(ChaChaPolyKey)
NonceMax = uint64.high - 1 # max is reserved NonceMax = uint64.high - 1 # max is reserved
NoiseSize = 32 NoiseSize = 32
MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len) MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len)
@ -72,7 +70,7 @@ type
Noise* = ref object of Secure Noise* = ref object of Secure
rng: ref BrHmacDrbgContext rng: ref BrHmacDrbgContext
localPrivateKey: PrivateKey localPrivateKey: PrivateKey
localPublicKey: PublicKey localPublicKey: seq[byte]
noiseKeys: KeyPair noiseKeys: KeyPair
commonPrologue: seq[byte] commonPrologue: seq[byte]
outgoing: bool outgoing: bool
@ -89,7 +87,7 @@ type
# Utility # Utility
proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair = proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair =
result.privateKey = Curve25519Key.random(rng).tryGet() result.privateKey = Curve25519Key.random(rng)
result.publicKey = result.privateKey.public() result.publicKey = result.privateKey.public()
proc hashProtocol(name: string): MDigest[256] = proc hashProtocol(name: string): MDigest[256] =
@ -110,12 +108,11 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
proc hasKey(cs: CipherState): bool = proc hasKey(cs: CipherState): bool =
cs.k != EmptyKey cs.k != EmptyKey
proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] = proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
var var
tag: ChaChaPolyTag tag: ChaChaPolyTag
nonce: ChaChaPolyNonce nonce: ChaChaPolyNonce
np = cast[ptr uint64](addr nonce[4]) nonce[4..<12] = toBytesLE(state.n)
np[] = state.n
result = @data result = @data
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad) ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
inc state.n inc state.n
@ -124,13 +121,12 @@ proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte]
result &= tag result &= tag
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1 trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
proc decryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] = proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
var var
tagIn = data[^ChaChaPolyTag.len..data.high].intoChaChaPolyTag tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag
tagOut = tagIn tagOut: ChaChaPolyTag
nonce: ChaChaPolyNonce nonce: ChaChaPolyNonce
np = cast[ptr uint64](addr nonce[4]) nonce[4..<12] = toBytesLE(state.n)
np[] = state.n
result = data[0..(data.high - ChaChaPolyTag.len)] result = data[0..(data.high - ChaChaPolyTag.len)]
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad) ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
@ -156,7 +152,7 @@ proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
ss.cs = CipherState(k: temp_keys[1]) ss.cs = CipherState(k: temp_keys[1])
trace "mixKey", key = ss.cs.k.shortLog trace "mixKey", key = ss.cs.k.shortLog
proc mixHash(ss: var SymmetricState; data: openarray[byte]) = proc mixHash(ss: var SymmetricState; data: openArray[byte]) =
var ctx: sha256 var ctx: sha256
ctx.init() ctx.init()
ctx.update(ss.h.data) ctx.update(ss.h.data)
@ -165,7 +161,7 @@ proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
trace "mixHash", hash = ss.h.data.shortLog trace "mixHash", hash = ss.h.data.shortLog
# We might use this for other handshake patterns/tokens # We might use this for other handshake patterns/tokens
proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} = proc mixKeyAndHash(ss: var SymmetricState; ikm: openArray[byte]) {.used.} =
var var
temp_keys: array[3, ChaChaPolyKey] temp_keys: array[3, ChaChaPolyKey]
sha256.hkdf(ss.ck, ikm, [], temp_keys) sha256.hkdf(ss.ck, ikm, [], temp_keys)
@ -173,7 +169,7 @@ proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
ss.mixHash(temp_keys[1]) ss.mixHash(temp_keys[1])
ss.cs = CipherState(k: temp_keys[2]) ss.cs = CipherState(k: temp_keys[2])
proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
# according to spec if key is empty leave plaintext # according to spec if key is empty leave plaintext
if ss.cs.hasKey: if ss.cs.hasKey:
result = ss.cs.encryptWithAd(ss.h.data, data) result = ss.cs.encryptWithAd(ss.h.data, data)
@ -181,7 +177,7 @@ proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
result = @data result = @data
ss.mixHash(result) ss.mixHash(result)
proc decryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
# according to spec if key is empty leave plaintext # according to spec if key is empty leave plaintext
if ss.cs.hasKey: if ss.cs.hasKey:
result = ss.cs.decryptWithAd(ss.h.data, data) result = ss.cs.decryptWithAd(ss.h.data, data)
@ -202,13 +198,13 @@ template write_e: untyped =
trace "noise write e" trace "noise write e"
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key). # Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
hs.e = genKeyPair(p.rng[]) hs.e = genKeyPair(p.rng[])
msg &= hs.e.publicKey msg.add hs.e.publicKey
hs.ss.mixHash(hs.e.publicKey) hs.ss.mixHash(hs.e.publicKey)
template write_s: untyped = template write_s: untyped =
trace "noise write s" trace "noise write s"
# Appends EncryptAndHash(s.public_key) to the buffer. # Appends EncryptAndHash(s.public_key) to the buffer.
msg &= hs.ss.encryptAndHash(hs.s.publicKey) msg.add hs.ss.encryptAndHash(hs.s.publicKey)
template dh_ee: untyped = template dh_ee: untyped =
trace "noise dh ee" trace "noise dh ee"
@ -244,8 +240,8 @@ template read_e: untyped =
raise newException(NoiseHandshakeError, "Noise E, expected more data") raise newException(NoiseHandshakeError, "Noise E, expected more data")
# Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key). # Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
hs.re[0..Curve25519Key.high] = msg[0..Curve25519Key.high] hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high)
msg = msg[Curve25519Key.len..msg.high] msg.consume(Curve25519Key.len)
hs.ss.mixHash(hs.re) hs.ss.mixHash(hs.re)
template read_s: untyped = template read_s: untyped =
@ -253,30 +249,33 @@ template read_s: untyped =
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise. # Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise.
# Sets rs (which must be empty) to DecryptAndHash(temp). # Sets rs (which must be empty) to DecryptAndHash(temp).
let let
temp = rsLen =
if hs.ss.cs.hasKey: if hs.ss.cs.hasKey:
if msg.len < Curve25519Key.len + ChaChaPolyTag.len: if msg.len < Curve25519Key.len + ChaChaPolyTag.len:
raise newException(NoiseHandshakeError, "Noise S, expected more data") raise newException(NoiseHandshakeError, "Noise S, expected more data")
msg[0..Curve25519Key.high + ChaChaPolyTag.len] Curve25519Key.len + ChaChaPolyTag.len
else: else:
if msg.len < Curve25519Key.len: if msg.len < Curve25519Key.len:
raise newException(NoiseHandshakeError, "Noise S, expected more data") raise newException(NoiseHandshakeError, "Noise S, expected more data")
msg[0..Curve25519Key.high] Curve25519Key.len
msg = msg[temp.len..msg.high] hs.rs[0..Curve25519Key.high] =
let plain = hs.ss.decryptAndHash(temp) hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1))
hs.rs[0..Curve25519Key.high] = plain
msg.consume(rsLen)
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} = proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
var besize: array[2, byte] var besize: array[2, byte]
await sconn.readExactly(addr besize[0], besize.len) await sconn.readExactly(addr besize[0], besize.len)
let size = uint16.fromBytesBE(besize).int let size = uint16.fromBytesBE(besize).int
trace "receiveHSMessage", size trace "receiveHSMessage", size
if size == 0:
return
var buffer = newSeq[byte](size) var buffer = newSeq[byte](size)
if buffer.len > 0: await sconn.readExactly(addr buffer[0], buffer.len)
await sconn.readExactly(addr buffer[0], buffer.len)
return buffer return buffer
proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] =
var var
lesize = buf.len.uint16 lesize = buf.len.uint16
besize = lesize.toBytesBE besize = lesize.toBytesBE
@ -284,97 +283,106 @@ proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
trace "sendHSMessage", size = lesize trace "sendHSMessage", size = lesize
outbuf &= besize outbuf &= besize
outbuf &= buf outbuf &= buf
await sconn.write(outbuf) sconn.write(outbuf)
proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} = proc handshakeXXOutbound(
p: Noise, conn: Connection,
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
const initiator = true const initiator = true
var var
hs = HandshakeState.init() hs = HandshakeState.init()
p2psecret = p2pProof.buffer
hs.ss.mixHash(p.commonPrologue) try:
hs.s = p.noiseKeys
# -> e hs.ss.mixHash(p.commonPrologue)
var msg: seq[byte] hs.s = p.noiseKeys
write_e() # -> e
var msg: StreamSeq
# IK might use this btw! write_e()
msg &= hs.ss.encryptAndHash(@[])
await conn.sendHSMessage(msg) # IK might use this btw!
msg.add hs.ss.encryptAndHash([])
# <- e, ee, s, es await conn.sendHSMessage(msg.data)
msg = await conn.receiveHSMessage() # <- e, ee, s, es
read_e() msg.assign(await conn.receiveHSMessage())
dh_ee()
read_s()
dh_es()
let remoteP2psecret = hs.ss.decryptAndHash(msg) read_e()
dh_ee()
read_s()
dh_es()
# -> s, se let remoteP2psecret = hs.ss.decryptAndHash(msg.data)
msg.clear()
msg.setLen(0) # -> s, se
write_s() write_s()
dh_se() dh_se()
# last payload must follow the ecrypted way of sending # last payload must follow the encrypted way of sending
msg &= hs.ss.encryptAndHash(p2psecret) msg.add hs.ss.encryptAndHash(p2psecret)
await conn.sendHSMessage(msg) await conn.sendHSMessage(msg.data)
let (cs1, cs2) = hs.ss.split() let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
finally:
burnMem(hs)
proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} = proc handshakeXXInbound(
p: Noise, conn: Connection,
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
const initiator = false const initiator = false
var var
hs = HandshakeState.init() hs = HandshakeState.init()
p2psecret = p2pProof.buffer
hs.ss.mixHash(p.commonPrologue) try:
hs.s = p.noiseKeys hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys
# -> e # -> e
var msg = await conn.receiveHSMessage() var msg: StreamSeq
msg.add(await conn.receiveHSMessage())
read_e() read_e()
# we might use this early data one day, keeping it here for clarity # we might use this early data one day, keeping it here for clarity
let earlyData {.used.} = hs.ss.decryptAndHash(msg) let earlyData {.used.} = hs.ss.decryptAndHash(msg.data)
# <- e, ee, s, es # <- e, ee, s, es
msg.setLen(0) msg.consume(msg.len)
write_e() write_e()
dh_ee() dh_ee()
write_s() write_s()
dh_es() dh_es()
msg &= hs.ss.encryptAndHash(p2psecret) msg.add hs.ss.encryptAndHash(p2psecret)
await conn.sendHSMessage(msg) await conn.sendHSMessage(msg.data)
msg.clear()
# -> s, se # -> s, se
msg = await conn.receiveHSMessage() msg.add(await conn.receiveHSMessage())
read_s() read_s()
dh_se() dh_se()
let remoteP2psecret = hs.ss.decryptAndHash(msg) let
remoteP2psecret = hs.ss.decryptAndHash(msg.data)
let (cs1, cs2) = hs.ss.split() (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
finally:
burnMem(hs)
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
while true: # Discard 0-length payloads while true: # Discard 0-length payloads
@ -399,7 +407,8 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
while left > 0: while left > 0:
let let
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
cipher = sconn.writeCs.encryptWithAd([], message.toOpenArray(offset, offset + chunkSize - 1)) cipher = sconn.writeCs.encryptWithAd(
[], message.toOpenArray(offset, offset + chunkSize - 1))
left = left - chunkSize left = left - chunkSize
offset = offset + chunkSize offset = offset + chunkSize
var var
@ -421,65 +430,75 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
var var
libp2pProof = initProtoBuffer() libp2pProof = initProtoBuffer()
libp2pProof.write(initProtoField(1, p.localPublicKey.getBytes.tryGet())) libp2pProof.write(initProtoField(1, p.localPublicKey))
libp2pProof.write(initProtoField(2, signedPayload.getBytes())) libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
# data field also there but not used! # data field also there but not used!
libp2pProof.finish() libp2pProof.finish()
let handshakeRes = var handshakeRes =
if initiator: if initiator:
await handshakeXXOutbound(p, conn, libp2pProof) await handshakeXXOutbound(p, conn, libp2pProof.buffer)
else: else:
await handshakeXXInbound(p, conn, libp2pProof) await handshakeXXInbound(p, conn, libp2pProof.buffer)
var var secure = try:
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret) var
remotePubKey: PublicKey remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
remotePubKeyBytes: seq[byte] remotePubKey: PublicKey
remoteSig: Signature remotePubKeyBytes: seq[byte]
remoteSigBytes: seq[byte] remoteSig: Signature
remoteSigBytes: seq[byte]
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remotePubKey.init(remotePubKeyBytes): if not remotePubKey.init(remotePubKeyBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remoteSig.init(remoteSigBytes): if not remoteSig.init(remoteSigBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
if not remoteSig.verify(verifyPayload, remotePubKey): if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else: else:
trace "Remote signature verified", peer = $conn trace "Remote signature verified", peer = $conn
if initiator and not isNil(conn.peerInfo): if initiator and not isNil(conn.peerInfo):
let pid = PeerID.init(remotePubKey) let pid = PeerID.init(remotePubKey)
if not conn.peerInfo.peerId.validate(): if not conn.peerInfo.peerId.validate():
raise newException(NoiseHandshakeError, "Failed to validate peerId.") raise newException(NoiseHandshakeError, "Failed to validate peerId.")
if pid.isErr or pid.get() != conn.peerInfo.peerId: if pid.isErr or pid.get() != conn.peerInfo.peerId:
var var
failedKey: PublicKey failedKey: PublicKey
discard extractPublicKey(conn.peerInfo.peerId, failedKey) discard extractPublicKey(conn.peerInfo.peerId, failedKey)
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
var secure = NoiseConnection.init(conn, var tmp = NoiseConnection.init(
PeerInfo.init(remotePubKey), conn, PeerInfo.init(remotePubKey), conn.observedAddr)
conn.observedAddr)
if initiator: if initiator:
secure.readCs = handshakeRes.cs2 tmp.readCs = handshakeRes.cs2
secure.writeCs = handshakeRes.cs1 tmp.writeCs = handshakeRes.cs1
else: else:
secure.readCs = handshakeRes.cs1 tmp.readCs = handshakeRes.cs1
secure.writeCs = handshakeRes.cs2 tmp.writeCs = handshakeRes.cs2
tmp
finally:
burnMem(handshakeRes)
trace "Noise handshake completed!", initiator, peer = $secure.peerInfo trace "Noise handshake completed!", initiator, peer = $secure.peerInfo
return secure return secure
method close*(s: NoiseConnection) {.async.} =
await procCall SecureConn(s).close()
burnMem(s.readCs)
burnMem(s.writeCs)
method init*(p: Noise) {.gcsafe.} = method init*(p: Noise) {.gcsafe.} =
procCall Secure(p).init() procCall Secure(p).init()
p.codec = NoiseCodec p.codec = NoiseCodec
@ -491,7 +510,7 @@ proc newNoise*(
rng: rng, rng: rng,
outgoing: outgoing, outgoing: outgoing,
localPrivateKey: privateKey, localPrivateKey: privateKey,
localPublicKey: privateKey.getKey().tryGet(), localPublicKey: privateKey.getKey().tryGet().getBytes().tryGet(),
noiseKeys: genKeyPair(rng[]), noiseKeys: genKeyPair(rng[]),
commonPrologue: commonPrologue, commonPrologue: commonPrologue,
) )

View File

@ -59,10 +59,9 @@ proc handleConn*(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): Future[Connection] {.async, gcsafe.} = initiator: bool): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn, initiator) var sconn = await s.handshake(conn, initiator)
if not isNil(sconn):
conn.closeEvent.wait() conn.closeEvent.wait()
.addCallback do(udata: pointer = nil): .addCallback do(udata: pointer = nil):
if not(isNil(sconn)):
asyncCheck sconn.close() asyncCheck sconn.close()
return sconn return sconn

View File

@ -61,6 +61,11 @@ template data*(v: StreamSeq): openArray[byte] =
# TODO a double-hash comment here breaks compile (!) # TODO a double-hash comment here breaks compile (!)
v.buf.toOpenArray(v.rpos, v.wpos - 1) v.buf.toOpenArray(v.rpos, v.wpos - 1)
template toOpenArray*(v: StreamSeq, b, e: int): openArray[byte] =
# Data that is ready to be consumed
# TODO a double-hash comment here breaks compile (!)
v.buf.toOpenArray(v.rpos + b, v.rpos + e - b)
func consume*(v: var StreamSeq, n: int) = func consume*(v: var StreamSeq, n: int) =
## Mark `n` bytes that were returned via `data` as consumed ## Mark `n` bytes that were returned via `data` as consumed
v.rpos += n v.rpos += n
@ -71,3 +76,10 @@ func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int =
copyMem(addr buf[0], addr v.buf[v.rpos], bytes) copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
v.consume(bytes) v.consume(bytes)
bytes bytes
func clear*(v: var StreamSeq) =
v.consume(v.len)
func assign*(v: var StreamSeq, buf: openArray[byte]) =
v.clear()
v.add(buf)