diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim index 7f05a8e..d5e328d 100644 --- a/libp2p/muxers/mplex/coder.nim +++ b/libp2p/muxers/mplex/coder.nim @@ -56,10 +56,7 @@ proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} = proc writeMsg*(conn: Connection, id: uint64, msgType: MessageType, - data: seq[byte] = @[]) {.async, gcsafe.} = - if conn.closed: - return # No point in trying to write to an already-closed connection - + data: seq[byte] = @[]): Future[void] = var left = data.len offset = 0 @@ -81,17 +78,9 @@ proc writeMsg*(conn: Connection, trace "writing mplex message", conn, id, msgType, data = data.len, encoded = buf.buffer.len - try: - # Write all chunks in a single write to avoid async races where a close - # message gets written before some of the chunks - await conn.write(buf.buffer) - trace "wrote mplex", conn, id, msgType - except CatchableError as exc: - # If the write to the underlying connection failed it should be closed so - # that the other channels are notified as well - trace "failed write", conn, id, msg = exc.msg - await conn.close() - raise exc + # Write all chunks in a single write to avoid async races where a close + # message gets written before some of the chunks + conn.write(buf.buffer) proc writeMsg*(conn: Connection, id: uint64, diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 9cc3dd3..48c162b 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -50,8 +50,6 @@ type resetCode*: MessageType # cached in/out reset code writes*: int # In-flight writes -proc open*(s: LPChannel) {.async, gcsafe.} - func shortLog*(s: LPChannel): auto = if s.isNil: "LPChannel(nil)" elif s.conn.peerInfo.isNil: $s.oid @@ -62,8 +60,14 @@ chronicles.formatIt(LPChannel): shortLog(it) proc open*(s: LPChannel) {.async, gcsafe.} = trace "Opening channel", s, conn = s.conn - await s.conn.writeMsg(s.id, MessageType.New, s.name) - s.isOpen = true + if s.conn.isClosed: + return + try: + await s.conn.writeMsg(s.id, MessageType.New, s.name) + s.isOpen = true + except CatchableError as exc: + await s.conn.close() + raise exc method closed*(s: LPChannel): bool = s.closedLocal @@ -88,10 +92,11 @@ proc reset*(s: LPChannel) {.async, gcsafe.} = # If the connection is still active, notify the other end proc resetMessage() {.async.} = try: - trace "sending reset message", s, conn = s.conn - await s.conn.writeMsg(s.id, s.resetCode) # write reset + trace "sending reset message", s, conn = s.conn + await s.conn.writeMsg(s.id, s.resetCode) # write reset except CatchableError as exc: - # No cancellations, errors handled in writeMsg + # No cancellations + await s.conn.close() trace "Can't send reset message", s, conn = s.conn, msg = exc.msg asyncSpawn resetMessage() @@ -115,10 +120,12 @@ method close*(s: LPChannel) {.async, gcsafe.} = try: await s.conn.writeMsg(s.id, s.closeCode) # write close except CancelledError as exc: + await s.conn.close() raise exc except CatchableError as exc: # It's harmless that close message cannot be sent - the connection is # likely down already + await s.conn.close() trace "Cannot send close message", s, id = s.id, msg = exc.msg await s.closeUnderlying() # maybe already eofed diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 38e45a4..8f4832a 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -121,18 +121,28 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key = proc hasKey(cs: CipherState): bool = cs.k != EmptyKey -proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = - var - tag: ChaChaPolyTag - nonce: ChaChaPolyNonce +proc encrypt( + state: var CipherState, data: var openArray[byte], + ad: openArray[byte]): ChaChaPolyTag {.noinit.} = + var nonce: ChaChaPolyNonce nonce[4..<12] = toBytesLE(state.n) - result = @data - ChaChaPoly.encrypt(state.k, nonce, tag, result, ad) + + ChaChaPoly.encrypt(state.k, nonce, result, data, ad) + inc state.n if state.n > NonceMax: raise newException(NoiseNonceMaxError, "Noise max nonce value reached") - result &= tag - trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1 + +proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = + result = newSeqOfCap[byte](data.len + sizeof(ChachaPolyTag)) + result.add(data) + + let tag = encrypt(state, result, ad) + + result.add(tag) + + trace "encryptWithAd", + tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1 proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = var @@ -417,20 +427,47 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = dumpMessage(sconn, FlowDirection.Incoming, []) trace "Received 0-length message", sconn + +proc encryptFrame( + sconn: NoiseConnection, cipherFrame: var openArray[byte], src: openArray[byte]) = + # Frame consists of length + cipher data + tag + doAssert src.len <= MaxPlainSize + doAssert cipherFrame.len == 2 + src.len + sizeof(ChaChaPolyTag) + + cipherFrame[0..<2] = toBytesBE(uint16(src.len + sizeof(ChaChaPolyTag))) + + copyMem(addr cipherFrame[2], unsafeAddr src[0], src.len()) + + let tag = encrypt( + sconn.writeCs, cipherFrame.toOpenArray(2, 2 + src.len() - 1), []) + + copyMem( + addr cipherFrame[cipherFrame.len - sizeof(tag)], unsafeAddr tag[0], + sizeof(tag)) + method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} = if message.len == 0: return + const FramingSize = 2 + sizeof(ChaChaPolyTag) + + let + frames = (message.len + MaxPlainSize - 1) div MaxPlainSize + var + cipherFrames = newSeqUninitialized[byte](message.len + frames * FramingSize) left = message.len offset = 0 + woffset = 0 + while left > 0: let chunkSize = min(MaxPlainSize, left) - cipher = sconn.writeCs.encryptWithAd( - [], message.toOpenArray(offset, offset + chunkSize - 1)) - await sconn.stream.writeFrame(cipher) + encryptFrame( + sconn, + cipherFrames.toOpenArray(woffset, woffset + chunkSize + FramingSize - 1), + message.toOpenArray(offset, offset + chunkSize - 1)) when defined(libp2p_dump): dumpMessage( @@ -438,9 +475,12 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async. message.toOpenArray(offset, offset + chunkSize - 1)) left = left - chunkSize - offset = offset + chunkSize + offset += chunkSize + woffset += chunkSize + FramingSize sconn.activity = true + await sconn.stream.write(cipherFrames) + method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} = trace "Starting Noise handshake", conn, initiator @@ -529,8 +569,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon return secure -method close*(s: NoiseConnection) {.async.} = - await procCall SecureConn(s).close() +method closeImpl*(s: NoiseConnection) {.async.} = + await procCall SecureConn(s).closeImpl() burnMem(s.readCs) burnMem(s.writeCs) diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 61ef59d..22395ee 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -56,29 +56,29 @@ method initStream*(s: SecureConn) = procCall Connection(s).initStream() -method close*(s: SecureConn) {.async.} = +method closeImpl*(s: SecureConn) {.async.} = trace "Closing secure conn", s, dir = s.dir if not(isNil(s.stream)): await s.stream.close() - await procCall Connection(s).close() + await procCall Connection(s).closeImpl() method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = doAssert(false, "Not implemented!") -method handshake(s: Secure, - conn: Connection, - initiator: bool): Future[SecureConn] {.async, base.} = +method handshake*(s: Secure, + conn: Connection, + initiator: bool): Future[SecureConn] {.async, base.} = doAssert(false, "Not implemented!") -proc handleConn*(s: Secure, +proc handleConn(s: Secure, conn: Connection, - initiator: bool): Future[Connection] {.async, gcsafe.} = + initiator: bool): Future[Connection] {.async.} = var sconn = await s.handshake(conn, initiator) proc cleanup() {.async.} = try: - let futs = @[conn.join(), sconn.join()] + let futs = [conn.join(), sconn.join()] await futs[0] or futs[1] for f in futs: if not f.finished: await f.cancelAndWait() # cancel outstanding join() @@ -90,7 +90,7 @@ proc handleConn*(s: Secure, # do not need to propagate CancelledError. discard except CatchableError as exc: - trace "error cleaning up secure connection", err = exc.msg, sconn + debug "error cleaning up secure connection", err = exc.msg, sconn if not isNil(sconn): # All the errors are handled inside `cleanup()` procedure. @@ -98,10 +98,10 @@ proc handleConn*(s: Secure, return sconn -method init*(s: Secure) {.gcsafe.} = +method init*(s: Secure) = procCall LPProtocol(s).init() - proc handle(conn: Connection, proto: string) {.async, gcsafe.} = + proc handle(conn: Connection, proto: string) {.async.} = trace "handling connection upgrade", proto, conn try: # We don't need the result but we @@ -121,36 +121,34 @@ method init*(s: Secure) {.gcsafe.} = method secure*(s: Secure, conn: Connection, initiator: bool): - Future[Connection] {.base, gcsafe.} = + Future[Connection] {.base.} = s.handleConn(conn, initiator) method readOnce*(s: SecureConn, pbytes: pointer, nbytes: int): - Future[int] {.async, gcsafe.} = + Future[int] {.async.} = doAssert(nbytes > 0, "nbytes must be positive integer") - if s.buf.data().len() == 0: - let (buf, err) = try: - (await s.readMessage(), nil) - except CatchableError as exc: - (@[], exc) + if s.isEof: + raise newLPStreamEOFError() - if not isNil(err): - if not (err of LPStreamEOFError): - debug "Error while reading message from secure connection, closing.", - error=err.name, - message=err.msg, - connection=s + if s.buf.data().len() == 0: + try: + let buf = await s.readMessage() # Always returns >0 bytes or raises + s.activity = true + s.buf.add(buf) + except LPStreamEOFError as err: + s.isEof = true + await s.close() + raise err + except CatchableError as err: + debug "Error while reading message from secure connection, closing.", + error = err.name, + message = err.msg, + connection = s await s.close() raise err - s.activity = true - - if buf.len == 0: - raise newLPStreamIncompleteError() - - s.buf.add(buf) - var p = cast[ptr UncheckedArray[byte]](pbytes) return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))