From b7e5d1122ce3dd5140788165ed86ebcb0a6002d8 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Wed, 16 Sep 2020 11:55:25 +0200 Subject: [PATCH] cleanups (#366) * reuse connection timeout for noise handshake (avoid extra timer) * enforce nbytes > 0 for readOnce * avoid some unnecessary memory zeroing * simplify noise * fix dumping when noise splits message --- libp2p/protocols/secure/noise.nim | 73 +++++++++++++++--------------- libp2p/protocols/secure/secure.nim | 7 ++- libp2p/stream/bufferstream.nim | 2 + 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 0ff11fdd1..e4a4628c2 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -276,28 +276,33 @@ template read_s: untyped = msg.consume(rsLen) -proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} = - var besize: array[2, byte] - await sconn.readExactly(addr besize[0], besize.len).wait(HandshakeTimeout) +proc readFrame(sconn: Connection): Future[seq[byte]] {.async.} = + var besize {.noinit.}: array[2, byte] + await sconn.readExactly(addr besize[0], besize.len) let size = uint16.fromBytesBE(besize).int - trace "receiveHSMessage", size + trace "readFrame", sconn, size if size == 0: return - var buffer = newSeq[byte](size) - await sconn.readExactly(addr buffer[0], buffer.len).wait(HandshakeTimeout) + var buffer = newSeqUninitialized[byte](size) + await sconn.readExactly(addr buffer[0], buffer.len) return buffer -proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] = +proc writeFrame(sconn: Connection, buf: openArray[byte]): Future[void] = + doAssert buf.len <= uint16.high.int var lesize = buf.len.uint16 besize = lesize.toBytesBE outbuf = newSeqOfCap[byte](besize.len + buf.len) - trace "sendHSMessage", size = lesize + trace "writeFrame", sconn, size = lesize, data = shortLog(buf) outbuf &= besize outbuf &= buf sconn.write(outbuf) +proc receiveHSMessage(sconn: Connection): Future[seq[byte]] = readFrame(sconn) +proc sendHSMessage(sconn: Connection, buf: openArray[byte]): Future[void] = + writeFrame(sconn, buf) + proc handshakeXXOutbound( p: Noise, conn: Connection, p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} = @@ -399,23 +404,18 @@ proc handshakeXXInbound( method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = while true: # Discard 0-length payloads - var besize: array[2, byte] - await sconn.stream.readExactly(addr besize[0], besize.len) - let size = uint16.fromBytesBE(besize).int # Cannot overflow - trace "receiveEncryptedMessage", size, sconn - if size > 0: - var buffer = newSeq[byte](size) - await sconn.stream.readExactly(addr buffer[0], buffer.len) - when defined(libp2p_dump): - let res = sconn.readCs.decryptWithAd([], buffer) - dumpMessage(sconn, FlowDirection.Incoming, res) + let frame = await sconn.stream.readFrame() + sconn.activity = true + if frame.len > 0: + let res = sconn.readCs.decryptWithAd([], frame) + if res.len > 0: + when defined(libp2p_dump): + dumpMessage(sconn, FlowDirection.Incoming, res) return res - else: - return sconn.readCs.decryptWithAd([], buffer) - else: - when defined(libp2p_dump): - dumpMessage(sconn, FlowDirection.Incoming, []) - trace "Received 0-length message", sconn + + when defined(libp2p_dump): + dumpMessage(sconn, FlowDirection.Incoming, []) + trace "Received 0-length message", sconn method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} = if message.len == 0: @@ -426,28 +426,27 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async. offset = 0 while left > 0: let - chunkSize = if left > MaxPlainSize: MaxPlainSize else: left + chunkSize = min(MaxPlainSize, left) cipher = sconn.writeCs.encryptWithAd( [], message.toOpenArray(offset, offset + chunkSize - 1)) - left = left - chunkSize - offset = offset + chunkSize - var - lesize = cipher.len.uint16 - besize = lesize.toBytesBE - outbuf = newSeqOfCap[byte](cipher.len + 2) - trace "sendEncryptedMessage", sconn, size = lesize, left, offset - outbuf &= besize - outbuf &= cipher - await sconn.stream.write(outbuf) + + await sconn.stream.writeFrame(cipher) when defined(libp2p_dump): - dumpMessage(sconn, FlowDirection.Outgoing, message) + dumpMessage( + sconn, FlowDirection.Outgoing, + message.toOpenArray(offset, offset + chunkSize - 1)) + left = left - chunkSize + offset = offset + chunkSize sconn.activity = true method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} = trace "Starting Noise handshake", conn, initiator + let timeout = conn.timeout + conn.timeout = HandshakeTimeout + # https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages let signedPayload = p.localPrivateKey.sign( @@ -524,6 +523,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon trace "Noise handshake completed!", initiator, peer = shortLog(secure.peerInfo) + conn.timeout = timeout + return secure method close*(s: NoiseConnection) {.async.} = diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 7caf6686c..e477be1ea 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -108,15 +108,14 @@ method init*(s: Secure) {.gcsafe.} = method secure*(s: Secure, conn: Connection, initiator: bool): - Future[Connection] {.async, base, gcsafe.} = - result = await s.handleConn(conn, initiator) + Future[Connection] {.base, gcsafe.} = + s.handleConn(conn, initiator) method readOnce*(s: SecureConn, pbytes: pointer, nbytes: int): Future[int] {.async, gcsafe.} = - if nbytes == 0: - return 0 + doAssert(nbytes > 0, "nbytes must be positive integer") if s.buf.data().len() == 0: let buf = await s.readMessage() diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 26056a752..d763042a7 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -128,6 +128,8 @@ method readOnce*(s: BufferStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = + doAssert(nbytes > 0, "nbytes must be positive integer") + if s.isEof and s.readBuf.len() == 0: raise newLPStreamEOFError()