From 45c089ff0d81938a51b97a82c4504871a9c286fb Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Thu, 9 Jul 2020 10:53:19 +0200 Subject: [PATCH] noise updates (#255) * clear secrets explicitly * simplify keygen * avoid some trivial memory allocations * fix little endian encoding of nonce --- libp2p/crypto/curve25519.nim | 13 +- libp2p/crypto/ecnist.nim | 10 +- libp2p/protocols/secure/noise.nim | 277 +++++++++++++++-------------- libp2p/protocols/secure/secure.nim | 7 +- libp2p/stream/streamseq.nim | 12 ++ 5 files changed, 173 insertions(+), 146 deletions(-) diff --git a/libp2p/crypto/curve25519.nim b/libp2p/crypto/curve25519.nim index fbb84a3b3..18580d790 100644 --- a/libp2p/crypto/curve25519.nim +++ b/libp2p/crypto/curve25519.nim @@ -105,10 +105,13 @@ proc mulgen*(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) proc public*(private: Curve25519Key): Curve25519Key = Curve25519.mulgen(result, private) -proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Result[Curve25519Key, Curve25519Error] = +proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Curve25519Key = var res: Curve25519Key let defaultBrEc = brEcGetDefault() - if brEcKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) != Curve25519KeySize: - err(Curver25519GenError) - else: - ok(res) + let len = brEcKeygen( + addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) + # Per bearssl documentation, the keygen only fails if the curve is + # unrecognised - + doAssert len == Curve25519KeySize, "Could not generate curve" + + res diff --git a/libp2p/crypto/ecnist.nim b/libp2p/crypto/ecnist.nim index c33ea90c8..30475ef4b 100644 --- a/libp2p/crypto/ecnist.nim +++ b/libp2p/crypto/ecnist.nim @@ -39,11 +39,11 @@ const type EcPrivateKey* = ref object - buffer*: seq[byte] + buffer*: array[BR_EC_KBUF_PRIV_MAX_SIZE, byte] key*: BrEcPrivateKey EcPublicKey* = ref object - buffer*: seq[byte] + buffer*: array[BR_EC_KBUF_PUB_MAX_SIZE, byte] key*: BrEcPublicKey EcKeyPair* = object @@ -237,7 +237,6 @@ proc random*( ## secp521r1). var ecimp = brEcGetDefault() var res = new EcPrivateKey - res.buffer = newSeq[byte](BR_EC_KBUF_PRIV_MAX_SIZE) if brEcKeygen(addr rng.vtable, ecimp, addr res.key, addr res.buffer[0], cast[cint](kind)) == 0: @@ -254,7 +253,6 @@ proc getKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] = if seckey.key.curve in EcSupportedCurvesCint: var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve)) var res = new EcPublicKey - res.buffer = newSeq[byte](length) if brEcComputePublicKey(ecimp, addr res.key, addr res.buffer[0], unsafeAddr seckey.key) == 0: err(EcKeyIncorrectError) @@ -621,7 +619,6 @@ proc init*(key: var EcPrivateKey, data: openarray[byte]): Result[void, Asn1Error if checkScalar(raw.toOpenArray(), curve) == 1'u32: key = new EcPrivateKey - key.buffer = newSeq[byte](raw.length) copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length) key.key.x = cast[ptr cuchar](addr key.buffer[0]) key.key.xlen = raw.length @@ -681,7 +678,6 @@ proc init*(pubkey: var EcPublicKey, data: openarray[byte]): Result[void, Asn1Err if checkPublic(raw.toOpenArray(), curve) != 0: pubkey = new EcPublicKey - pubkey.buffer = newSeq[byte](raw.length) copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length) pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0]) pubkey.key.qlen = raw.length @@ -769,7 +765,6 @@ proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool = if checkScalar(data, curve) == 1'u32: let length = len(data) key = new EcPrivateKey - key.buffer = newSeq[byte](length) copyMem(addr key.buffer[0], unsafeAddr data[0], length) key.key.x = cast[ptr cuchar](addr key.buffer[0]) key.key.xlen = length @@ -801,7 +796,6 @@ proc initRaw*(pubkey: var EcPublicKey, data: openarray[byte]): bool = if checkPublic(data, curve) != 0: let length = len(data) pubkey = new EcPublicKey - pubkey.buffer = newSeq[byte](length) copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length) pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0]) pubkey.key.qlen = length diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 929104bd2..fa14674bb 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -12,15 +12,13 @@ import chronicles import bearssl import stew/[endians2, byteutils] import nimcrypto/[utils, sha2, hmac] -import ../../stream/lpstream +import ../../stream/[connection, streamseq] import ../../peerid import ../../peerinfo import ../../protobuf/minprotobuf import ../../utility -import ../../stream/lpstream import secure, - ../../crypto/[crypto, chacha20poly1305, curve25519, hkdf], - ../../stream/bufferstream + ../../crypto/[crypto, chacha20poly1305, curve25519, hkdf] logScope: topics = "noise" @@ -34,7 +32,7 @@ const ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256" # Empty is a special value which indicates k has not yet been initialized. - EmptyKey: ChaChaPolyKey = [0.byte, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + EmptyKey = default(ChaChaPolyKey) NonceMax = uint64.high - 1 # max is reserved NoiseSize = 32 MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len) @@ -72,7 +70,7 @@ type Noise* = ref object of Secure rng: ref BrHmacDrbgContext localPrivateKey: PrivateKey - localPublicKey: PublicKey + localPublicKey: seq[byte] noiseKeys: KeyPair commonPrologue: seq[byte] outgoing: bool @@ -89,7 +87,7 @@ type # Utility proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair = - result.privateKey = Curve25519Key.random(rng).tryGet() + result.privateKey = Curve25519Key.random(rng) result.publicKey = result.privateKey.public() proc hashProtocol(name: string): MDigest[256] = @@ -110,12 +108,11 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key = proc hasKey(cs: CipherState): bool = cs.k != EmptyKey -proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] = +proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = var tag: ChaChaPolyTag nonce: ChaChaPolyNonce - np = cast[ptr uint64](addr nonce[4]) - np[] = state.n + nonce[4..<12] = toBytesLE(state.n) result = @data ChaChaPoly.encrypt(state.k, nonce, tag, result, ad) inc state.n @@ -124,13 +121,12 @@ proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] result &= tag trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1 -proc decryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] = +proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = var - tagIn = data[^ChaChaPolyTag.len..data.high].intoChaChaPolyTag - tagOut = tagIn + tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag + tagOut: ChaChaPolyTag nonce: ChaChaPolyNonce - np = cast[ptr uint64](addr nonce[4]) - np[] = state.n + nonce[4..<12] = toBytesLE(state.n) result = data[0..(data.high - ChaChaPolyTag.len)] ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad) trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n @@ -156,7 +152,7 @@ proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) = ss.cs = CipherState(k: temp_keys[1]) trace "mixKey", key = ss.cs.k.shortLog -proc mixHash(ss: var SymmetricState; data: openarray[byte]) = +proc mixHash(ss: var SymmetricState; data: openArray[byte]) = var ctx: sha256 ctx.init() ctx.update(ss.h.data) @@ -165,7 +161,7 @@ proc mixHash(ss: var SymmetricState; data: openarray[byte]) = trace "mixHash", hash = ss.h.data.shortLog # We might use this for other handshake patterns/tokens -proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} = +proc mixKeyAndHash(ss: var SymmetricState; ikm: openArray[byte]) {.used.} = var temp_keys: array[3, ChaChaPolyKey] sha256.hkdf(ss.ck, ikm, [], temp_keys) @@ -173,7 +169,7 @@ proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} = ss.mixHash(temp_keys[1]) ss.cs = CipherState(k: temp_keys[2]) -proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = +proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] = # according to spec if key is empty leave plaintext if ss.cs.hasKey: result = ss.cs.encryptWithAd(ss.h.data, data) @@ -181,7 +177,7 @@ proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = result = @data ss.mixHash(result) -proc decryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = +proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] = # according to spec if key is empty leave plaintext if ss.cs.hasKey: result = ss.cs.decryptWithAd(ss.h.data, data) @@ -202,13 +198,13 @@ template write_e: untyped = trace "noise write e" # Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key). hs.e = genKeyPair(p.rng[]) - msg &= hs.e.publicKey + msg.add hs.e.publicKey hs.ss.mixHash(hs.e.publicKey) template write_s: untyped = trace "noise write s" # Appends EncryptAndHash(s.public_key) to the buffer. - msg &= hs.ss.encryptAndHash(hs.s.publicKey) + msg.add hs.ss.encryptAndHash(hs.s.publicKey) template dh_ee: untyped = trace "noise dh ee" @@ -244,8 +240,8 @@ template read_e: untyped = raise newException(NoiseHandshakeError, "Noise E, expected more data") # Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key). - hs.re[0..Curve25519Key.high] = msg[0..Curve25519Key.high] - msg = msg[Curve25519Key.len..msg.high] + hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high) + msg.consume(Curve25519Key.len) hs.ss.mixHash(hs.re) template read_s: untyped = @@ -253,30 +249,33 @@ template read_s: untyped = # Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise. # Sets rs (which must be empty) to DecryptAndHash(temp). let - temp = + rsLen = if hs.ss.cs.hasKey: if msg.len < Curve25519Key.len + ChaChaPolyTag.len: raise newException(NoiseHandshakeError, "Noise S, expected more data") - msg[0..Curve25519Key.high + ChaChaPolyTag.len] + Curve25519Key.len + ChaChaPolyTag.len else: if msg.len < Curve25519Key.len: raise newException(NoiseHandshakeError, "Noise S, expected more data") - msg[0..Curve25519Key.high] - msg = msg[temp.len..msg.high] - let plain = hs.ss.decryptAndHash(temp) - hs.rs[0..Curve25519Key.high] = plain + Curve25519Key.len + hs.rs[0..Curve25519Key.high] = + hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1)) + + msg.consume(rsLen) proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} = var besize: array[2, byte] await sconn.readExactly(addr besize[0], besize.len) let size = uint16.fromBytesBE(besize).int trace "receiveHSMessage", size + if size == 0: + return + var buffer = newSeq[byte](size) - if buffer.len > 0: - await sconn.readExactly(addr buffer[0], buffer.len) + await sconn.readExactly(addr buffer[0], buffer.len) return buffer -proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = +proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] = var lesize = buf.len.uint16 besize = lesize.toBytesBE @@ -284,97 +283,106 @@ proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = trace "sendHSMessage", size = lesize outbuf &= besize outbuf &= buf - await sconn.write(outbuf) + sconn.write(outbuf) -proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} = +proc handshakeXXOutbound( + p: Noise, conn: Connection, + p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} = const initiator = true - var hs = HandshakeState.init() - p2psecret = p2pProof.buffer - hs.ss.mixHash(p.commonPrologue) - hs.s = p.noiseKeys + try: - # -> e - var msg: seq[byte] + hs.ss.mixHash(p.commonPrologue) + hs.s = p.noiseKeys - write_e() + # -> e + var msg: StreamSeq - # IK might use this btw! - msg &= hs.ss.encryptAndHash(@[]) + write_e() - await conn.sendHSMessage(msg) + # IK might use this btw! + msg.add hs.ss.encryptAndHash([]) - # <- e, ee, s, es + await conn.sendHSMessage(msg.data) - msg = await conn.receiveHSMessage() + # <- e, ee, s, es - read_e() - dh_ee() - read_s() - dh_es() + msg.assign(await conn.receiveHSMessage()) - let remoteP2psecret = hs.ss.decryptAndHash(msg) + read_e() + dh_ee() + read_s() + dh_es() - # -> s, se + let remoteP2psecret = hs.ss.decryptAndHash(msg.data) + msg.clear() - msg.setLen(0) + # -> s, se - write_s() - dh_se() + write_s() + dh_se() - # last payload must follow the ecrypted way of sending - msg &= hs.ss.encryptAndHash(p2psecret) + # last payload must follow the encrypted way of sending + msg.add hs.ss.encryptAndHash(p2psecret) - await conn.sendHSMessage(msg) + await conn.sendHSMessage(msg.data) - let (cs1, cs2) = hs.ss.split() - return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + let (cs1, cs2) = hs.ss.split() + return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + finally: + burnMem(hs) -proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} = +proc handshakeXXInbound( + p: Noise, conn: Connection, + p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} = const initiator = false var hs = HandshakeState.init() - p2psecret = p2pProof.buffer - hs.ss.mixHash(p.commonPrologue) - hs.s = p.noiseKeys + try: + hs.ss.mixHash(p.commonPrologue) + hs.s = p.noiseKeys - # -> e + # -> e - var msg = await conn.receiveHSMessage() + var msg: StreamSeq + msg.add(await conn.receiveHSMessage()) - read_e() + read_e() - # we might use this early data one day, keeping it here for clarity - let earlyData {.used.} = hs.ss.decryptAndHash(msg) + # we might use this early data one day, keeping it here for clarity + let earlyData {.used.} = hs.ss.decryptAndHash(msg.data) - # <- e, ee, s, es + # <- e, ee, s, es - msg.setLen(0) + msg.consume(msg.len) - write_e() - dh_ee() - write_s() - dh_es() + write_e() + dh_ee() + write_s() + dh_es() - msg &= hs.ss.encryptAndHash(p2psecret) + msg.add hs.ss.encryptAndHash(p2psecret) - await conn.sendHSMessage(msg) + await conn.sendHSMessage(msg.data) + msg.clear() - # -> s, se + # -> s, se - msg = await conn.receiveHSMessage() + msg.add(await conn.receiveHSMessage()) - read_s() - dh_se() + read_s() + dh_se() - let remoteP2psecret = hs.ss.decryptAndHash(msg) - - let (cs1, cs2) = hs.ss.split() - return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + let + remoteP2psecret = hs.ss.decryptAndHash(msg.data) + (cs1, cs2) = hs.ss.split() + return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + finally: + burnMem(hs) method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = while true: # Discard 0-length payloads @@ -399,7 +407,8 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async. while left > 0: let chunkSize = if left > MaxPlainSize: MaxPlainSize else: left - cipher = sconn.writeCs.encryptWithAd([], message.toOpenArray(offset, offset + chunkSize - 1)) + cipher = sconn.writeCs.encryptWithAd( + [], message.toOpenArray(offset, offset + chunkSize - 1)) left = left - chunkSize offset = offset + chunkSize var @@ -421,65 +430,75 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon var libp2pProof = initProtoBuffer() - libp2pProof.write(initProtoField(1, p.localPublicKey.getBytes.tryGet())) + libp2pProof.write(initProtoField(1, p.localPublicKey)) libp2pProof.write(initProtoField(2, signedPayload.getBytes())) # data field also there but not used! libp2pProof.finish() - let handshakeRes = + var handshakeRes = if initiator: - await handshakeXXOutbound(p, conn, libp2pProof) + await handshakeXXOutbound(p, conn, libp2pProof.buffer) else: - await handshakeXXInbound(p, conn, libp2pProof) + await handshakeXXInbound(p, conn, libp2pProof.buffer) - var - remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret) - remotePubKey: PublicKey - remotePubKeyBytes: seq[byte] - remoteSig: Signature - remoteSigBytes: seq[byte] + var secure = try: + var + remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret) + remotePubKey: PublicKey + remotePubKeyBytes: seq[byte] + remoteSig: Signature + remoteSigBytes: seq[byte] - if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: - raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: - raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: + raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: + raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if not remotePubKey.init(remotePubKeyBytes): - raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if not remoteSig.init(remoteSigBytes): - raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if not remotePubKey.init(remotePubKeyBytes): + raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if not remoteSig.init(remoteSigBytes): + raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes - if not remoteSig.verify(verifyPayload, remotePubKey): - raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") - else: - trace "Remote signature verified", peer = $conn + let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes + if not remoteSig.verify(verifyPayload, remotePubKey): + raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") + else: + trace "Remote signature verified", peer = $conn - if initiator and not isNil(conn.peerInfo): - let pid = PeerID.init(remotePubKey) - if not conn.peerInfo.peerId.validate(): - raise newException(NoiseHandshakeError, "Failed to validate peerId.") - if pid.isErr or pid.get() != conn.peerInfo.peerId: - var - failedKey: PublicKey - discard extractPublicKey(conn.peerInfo.peerId, failedKey) - debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey - raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) + if initiator and not isNil(conn.peerInfo): + let pid = PeerID.init(remotePubKey) + if not conn.peerInfo.peerId.validate(): + raise newException(NoiseHandshakeError, "Failed to validate peerId.") + if pid.isErr or pid.get() != conn.peerInfo.peerId: + var + failedKey: PublicKey + discard extractPublicKey(conn.peerInfo.peerId, failedKey) + debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey + raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) - var secure = NoiseConnection.init(conn, - PeerInfo.init(remotePubKey), - conn.observedAddr) - if initiator: - secure.readCs = handshakeRes.cs2 - secure.writeCs = handshakeRes.cs1 - else: - secure.readCs = handshakeRes.cs1 - secure.writeCs = handshakeRes.cs2 + var tmp = NoiseConnection.init( + conn, PeerInfo.init(remotePubKey), conn.observedAddr) + + if initiator: + tmp.readCs = handshakeRes.cs2 + tmp.writeCs = handshakeRes.cs1 + else: + tmp.readCs = handshakeRes.cs1 + tmp.writeCs = handshakeRes.cs2 + tmp + finally: + burnMem(handshakeRes) trace "Noise handshake completed!", initiator, peer = $secure.peerInfo return secure +method close*(s: NoiseConnection) {.async.} = + await procCall SecureConn(s).close() + + burnMem(s.readCs) + burnMem(s.writeCs) + method init*(p: Noise) {.gcsafe.} = procCall Secure(p).init() p.codec = NoiseCodec @@ -491,7 +510,7 @@ proc newNoise*( rng: rng, outgoing: outgoing, localPrivateKey: privateKey, - localPublicKey: privateKey.getKey().tryGet(), + localPublicKey: privateKey.getKey().tryGet().getBytes().tryGet(), noiseKeys: genKeyPair(rng[]), commonPrologue: commonPrologue, ) diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 235de6e33..7f05fe49e 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -59,10 +59,9 @@ proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn, initiator) - - conn.closeEvent.wait() - .addCallback do(udata: pointer = nil): - if not(isNil(sconn)): + if not isNil(sconn): + conn.closeEvent.wait() + .addCallback do(udata: pointer = nil): asyncCheck sconn.close() return sconn diff --git a/libp2p/stream/streamseq.nim b/libp2p/stream/streamseq.nim index bbb44aeba..30f62ad7f 100644 --- a/libp2p/stream/streamseq.nim +++ b/libp2p/stream/streamseq.nim @@ -61,6 +61,11 @@ template data*(v: StreamSeq): openArray[byte] = # TODO a double-hash comment here breaks compile (!) v.buf.toOpenArray(v.rpos, v.wpos - 1) +template toOpenArray*(v: StreamSeq, b, e: int): openArray[byte] = + # Data that is ready to be consumed + # TODO a double-hash comment here breaks compile (!) + v.buf.toOpenArray(v.rpos + b, v.rpos + e - b) + func consume*(v: var StreamSeq, n: int) = ## Mark `n` bytes that were returned via `data` as consumed v.rpos += n @@ -71,3 +76,10 @@ func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int = copyMem(addr buf[0], addr v.buf[v.rpos], bytes) v.consume(bytes) bytes + +func clear*(v: var StreamSeq) = + v.consume(v.len) + +func assign*(v: var StreamSeq, buf: openArray[byte]) = + v.clear() + v.add(buf)