From 3effb95f10d5019f971f36a6f8c02d5dfbbdd7d2 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Fri, 27 Mar 2020 08:25:52 -0600 Subject: [PATCH] close underlying bufferstream in lpchannel --- libp2p/connection.nim | 15 ++++-- libp2p/multistream.nim | 2 +- libp2p/muxers/mplex/lpchannel.nim | 70 ++++++++++++------------- libp2p/muxers/mplex/mplex.nim | 2 +- libp2p/protocols/identify.nim | 3 +- libp2p/protocols/pubsub/pubsubpeer.nim | 4 +- libp2p/protocols/secure/secure.nim | 5 +- libp2p/stream/bufferstream.nim | 12 +++-- libp2p/stream/chronosstream.nim | 72 +++++++------------------- libp2p/stream/lpstream.nim | 26 +++++++--- libp2p/switch.nim | 2 +- libp2p/transports/transport.nim | 2 +- tests/testmplex.nim | 37 ++++++------- tests/testswitch.nim | 56 ++++++++++++-------- 14 files changed, 157 insertions(+), 151 deletions(-) diff --git a/libp2p/connection.nim b/libp2p/connection.nim index 64b3066fe..2d10c7640 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -15,6 +15,9 @@ import peerinfo, varint, vbuffer +logScope: + topic = "Connection" + const DefaultReadSize* = 1 shl 20 type @@ -22,21 +25,23 @@ type peerInfo*: PeerInfo stream*: LPStream observedAddrs*: Multiaddress + maxReadSize: int InvalidVarintException = object of LPStreamError InvalidVarintSizeException = object of LPStreamError proc newInvalidVarintException*(): ref InvalidVarintException = - newException(InvalidVarintException, "unable to parse varint") + newException(InvalidVarintException, "Unable to parse varint") proc newInvalidVarintSizeException*(): ref InvalidVarintSizeException = - newException(InvalidVarintSizeException, "wrong varint size") + newException(InvalidVarintSizeException, "Wrong varint size") -proc init*[T: Connection](self: var T, stream: LPStream) = +proc init*[T: Connection](self: var T, stream: LPStream, maxReadSize = DefaultReadSize) = ## create a new Connection for the specified async reader/writer new self self.stream = stream self.closeEvent = newAsyncEvent() + self.maxReadSize = maxReadSize # bind stream's close event to connection's close # to ensure correct close propagation @@ -45,7 +50,7 @@ proc init*[T: Connection](self: var T, stream: LPStream) = self.stream.closeEvent.wait(). addCallback do (udata: pointer): if not this.closed: - trace "closing this connection because wrapped stream closed" + trace "wrapped stream closed, closing conn" asyncCheck this.close() proc newConnection*(stream: LPStream): Connection = @@ -128,7 +133,7 @@ proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} = break if res != VarintStatus.Success: raise newInvalidVarintException() - if size.int > DefaultReadSize: + if size.int > s.maxReadSize: raise newInvalidVarintSizeException() buff.setLen(size) if size > 0.uint: diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index a03ee0e36..2910f10d6 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -154,7 +154,7 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = warn "no handlers for ", protocol = ms await conn.write(m.na) except CatchableError as exc: - trace "exception occurred in MultistreamSelect.handle", exc = exc.msg + trace "Exception occurred", exc = exc.msg finally: trace "leaving multistream loop" diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 4b0fe4339..6ffed67d8 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -65,15 +65,22 @@ proc newChannel*(id: uint64, proc closeMessage(s: LPChannel) {.async.} = await s.conn.writeMsg(s.id, s.closeCode) # write header -proc closedByRemote*(s: LPChannel) {.async.} = - s.closedRemote = true - proc cleanUp*(s: LPChannel): Future[void] = # method which calls the underlying buffer's `close` # method used instead of `close` since it's overloaded to # simulate half-closed streams result = procCall close(BufferStream(s)) +proc tryCleanup(s: LPChannel) {.async, inline.} = + # if stream is EOF, then cleanup immediatelly + if s.closedRemote and s.len == 0: + await s.cleanUp() + +proc closedByRemote*(s: LPChannel) {.async.} = + s.closedRemote = true + if s.len == 0: + await s.cleanUp() + proc open*(s: LPChannel): Future[void] = s.isOpen = true s.conn.writeMsg(s.id, MessageType.New, s.name) @@ -88,11 +95,13 @@ proc resetMessage(s: LPChannel) {.async.} = proc resetByRemote*(s: LPChannel) {.async.} = await allFutures(s.close(), s.closedByRemote()) s.isReset = true + await s.cleanUp() proc reset*(s: LPChannel) {.async.} = await allFutures(s.resetMessage(), s.resetByRemote()) method closed*(s: LPChannel): bool = + trace "closing lpchannel", id = s.id, initiator = s.initiator result = s.closedRemote and s.len == 0 proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] = @@ -107,57 +116,46 @@ proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] = result = procCall pushTo(BufferStream(s), data) -method read*(s: LPChannel, n = -1): Future[seq[byte]] = +template raiseEOF(): untyped = if s.closed or s.isReset: - var retFuture = newFuture[seq[byte]]("LPChannel.read") - retFuture.fail(newLPStreamEOFError()) - return retFuture + raise newLPStreamEOFError() - result = procCall read(BufferStream(s), n) +method read*(s: LPChannel, n = -1): Future[seq[byte]] {.async.} = + raiseEOF() + result = (await procCall(read(BufferStream(s), n))) + await s.tryCleanup() method readExactly*(s: LPChannel, pbytes: pointer, nbytes: int): - Future[void] = - if s.closed or s.isReset: - var retFuture = newFuture[void]("LPChannel.readExactly") - retFuture.fail(newLPStreamEOFError()) - return retFuture - - result = procCall readExactly(BufferStream(s), pbytes, nbytes) + Future[void] {.async.} = + raiseEOF() + await procCall readExactly(BufferStream(s), pbytes, nbytes) + await s.tryCleanup() method readLine*(s: LPChannel, limit = 0, sep = "\r\n"): - Future[string] = - if s.closed or s.isReset: - var retFuture = newFuture[string]("LPChannel.readLine") - retFuture.fail(newLPStreamEOFError()) - return retFuture - - result = procCall readLine(BufferStream(s), limit, sep) + Future[string] {.async.} = + raiseEOF() + result = await procCall readLine(BufferStream(s), limit, sep) + await s.tryCleanup() method readOnce*(s: LPChannel, pbytes: pointer, nbytes: int): - Future[int] = - if s.closed or s.isReset: - var retFuture = newFuture[int]("LPChannel.readOnce") - retFuture.fail(newLPStreamEOFError()) - return retFuture - - result = procCall readOnce(BufferStream(s), pbytes, nbytes) + Future[int] {.async.} = + raiseEOF() + result = await procCall readOnce(BufferStream(s), pbytes, nbytes) + await s.tryCleanup() method readUntil*(s: LPChannel, pbytes: pointer, nbytes: int, sep: seq[byte]): - Future[int] = - if s.closed or s.isReset: - var retFuture = newFuture[int]("LPChannel.readUntil") - retFuture.fail(newLPStreamEOFError()) - return retFuture - - result = procCall readOnce(BufferStream(s), pbytes, nbytes) + Future[int] {.async.} = + raiseEOF() + result = await procCall readOnce(BufferStream(s), pbytes, nbytes) + await s.tryCleanup() template writePrefix: untyped = if s.isLazy and not s.isOpen: diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index dd801f236..5e758ce8a 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -123,7 +123,7 @@ method handle*(m: Mplex) {.async, gcsafe.} = m.getChannelList(initiator).del(id) break except CatchableError as exc: - trace "exception occurred", exception = exc.msg + trace "Exception occurred", exception = exc.msg finally: trace "stopping mplex main loop" if not m.connection.closed(): diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index a06e99d0a..7eae82146 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -19,7 +19,7 @@ import ../protobuf/minprotobuf, ../utility logScope: - topic = "identify" + topic = "Identify" const IdentifyCodec* = "/ipfs/id/1.0.0" @@ -123,6 +123,7 @@ method init*(p: Identify) = proc identify*(p: Identify, conn: Connection, remotePeerInfo: PeerInfo): Future[IdentifyInfo] {.async, gcsafe.} = + trace "initiating identify" var message = await conn.readLp() if len(message) == 0: trace "identify: Invalid or empty message received!" diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index dc7a99617..438a07080 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -63,7 +63,7 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} = await p.handler(p, @[msg]) p.recvdRpcCache.put($hexData.hash) except CatchableError as exc: - error "exception occurred in PubSubPeer.handle", exc = exc.msg + trace "Exception occurred in PubSubPeer.handle", exc = exc.msg finally: trace "exiting pubsub peer read loop", peer = p.id if not conn.closed(): @@ -101,7 +101,7 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = encoded = encodedHex except CatchableError as exc: - trace "exception occurred in PubSubPeer.send", exc = exc.msg + trace "Exception occurred in PubSubPeer.send", exc = exc.msg proc sendMsg*(p: PubSubPeer, peerId: PeerID, diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 12114a2ca..5c6815849 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -42,7 +42,7 @@ proc readLoop(sconn: SecureConn, stream: BufferStream) {.async.} = await stream.pushTo(msg) except CatchableError as exc: - trace "exception occurred Secure.readLoop", exc = exc.msg + trace "Exception occurred Secure.readLoop", exc = exc.msg finally: if not sconn.closed: await sconn.close() @@ -63,7 +63,8 @@ proc handleConn*(s: Secure, conn: Connection, initiator: bool = false): Future[C if not isNil(sconn) and not sconn.closed: asyncCheck sconn.close() - result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) + if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: + result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) method init*(s: Secure) {.gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} = diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 14daaf829..43425e847 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -31,7 +31,7 @@ ## buffer goes below ``maxSize`` or more data becomes available. import deques, math -import chronos +import chronos, chronicles import ../stream/lpstream const DefaultBufferSize* = 1024 @@ -154,7 +154,12 @@ method readExactly*(s: BufferStream, ## If EOF is received and ``nbytes`` is not yet read, the procedure ## will raise ``LPStreamIncompleteError``. ## - var buff = await s.read(nbytes) + var buff: seq[byte] + try: + buff = await s.read(nbytes) + except LPStreamEOFError as exc: + trace "Exception occured", exc = exc.msg + if nbytes > buff.len(): raise newLPStreamIncompleteError() @@ -362,9 +367,10 @@ proc `|`*(s: BufferStream, target: BufferStream): BufferStream = method close*(s: BufferStream) {.async.} = ## close the stream and clear the buffer + trace "closing bufferstream" for r in s.readReqs: if not(isNil(r)) and not(r.finished()): - r.cancel() + r.fail(newLPStreamEOFError()) s.dataReadEvent.fire() s.readBuf.clear() s.closeEvent.fire() diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 75b124102..433fa1a20 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -28,16 +28,24 @@ proc newChronosStream*(server: StreamServer, result.writer = newAsyncStreamWriter(client) result.closeEvent = newAsyncEvent() +template withExceptions(body: untyped) = + try: + body + except TransportIncompleteError: + raise newLPStreamIncompleteError() + except TransportLimitError: + raise newLPStreamLimitError() + except TransportError as exc: + raise newLPStreamIncorrectError(exc.msg) + except AsyncStreamIncompleteError: + raise newLPStreamIncompleteError() + method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async.} = if s.reader.atEof: raise newLPStreamEOFError() - try: + withExceptions: result = await s.reader.read(n) - except AsyncStreamReadError as exc: - raise newLPStreamReadError(exc.par) - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method readExactly*(s: ChronosStream, pbytes: pointer, @@ -45,36 +53,22 @@ method readExactly*(s: ChronosStream, if s.reader.atEof: raise newLPStreamEOFError() - try: + withExceptions: await s.reader.readExactly(pbytes, nbytes) - except AsyncStreamIncompleteError: - raise newLPStreamIncompleteError() - except AsyncStreamReadError as exc: - raise newLPStreamReadError(exc.par) - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async.} = if s.reader.atEof: raise newLPStreamEOFError() - try: + withExceptions: result = await s.reader.readLine(limit, sep) - except AsyncStreamReadError as exc: - raise newLPStreamReadError(exc.par) - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = if s.reader.atEof: raise newLPStreamEOFError() - try: + withExceptions: result = await s.reader.readOnce(pbytes, nbytes) - except AsyncStreamReadError as exc: - raise newLPStreamReadError(exc.par) - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method readUntil*(s: ChronosStream, pbytes: pointer, @@ -83,55 +77,29 @@ method readUntil*(s: ChronosStream, if s.reader.atEof: raise newLPStreamEOFError() - try: + withExceptions: result = await s.reader.readUntil(pbytes, nbytes, sep) - except AsyncStreamIncompleteError: - raise newLPStreamIncompleteError() - except AsyncStreamLimitError: - raise newLPStreamLimitError() - except LPStreamReadError as exc: - raise newLPStreamReadError(exc.par) - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async.} = if s.writer.atEof: raise newLPStreamEOFError() - try: + withExceptions: await s.writer.write(pbytes, nbytes) - except AsyncStreamWriteError as exc: - raise newLPStreamWriteError(exc.par) - except AsyncStreamIncompleteError: - raise newLPStreamIncompleteError() - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method write*(s: ChronosStream, msg: string, msglen = -1) {.async.} = if s.writer.atEof: raise newLPStreamEOFError() - try: + withExceptions: await s.writer.write(msg, msglen) - except AsyncStreamWriteError as exc: - raise newLPStreamWriteError(exc.par) - except AsyncStreamIncompleteError: - raise newLPStreamIncompleteError() - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} = if s.writer.atEof: raise newLPStreamEOFError() - try: + withExceptions: await s.writer.write(msg, msglen) - except AsyncStreamWriteError as exc: - raise newLPStreamWriteError(exc.par) - except AsyncStreamIncompleteError: - raise newLPStreamIncompleteError() - except AsyncStreamIncorrectError as exc: - raise newLPStreamIncorrectError(exc.msg) method closed*(s: ChronosStream): bool {.inline.} = # TODO: we might only need to check for reader's EOF diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 95b178ac0..455190b88 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -51,24 +51,36 @@ proc newLPStreamEOFError*(): ref Exception {.inline.} = method closed*(s: LPStream): bool {.base, inline.} = s.isClosed -method read*(s: LPStream, n = -1): Future[seq[byte]] {.base, async.} = +method read*(s: LPStream, + n = -1): + Future[seq[byte]] {.base, async.} = doAssert(false, "not implemented!") -method readExactly*(s: LPStream, pbytes: pointer, - nbytes: int): Future[void] {.base, async.} = +method readExactly*(s: LPStream, + pbytes: pointer, + nbytes: int): + Future[void] {.base, async.} = doAssert(false, "not implemented!") -method readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] +method readLine*(s: LPStream, + limit = 0, + sep = "\r\n"): + Future[string] {.base, async.} = doAssert(false, "not implemented!") -method readOnce*(s: LPStream, pbytes: pointer, nbytes: int): Future[int] +method readOnce*(s: LPStream, + pbytes: pointer, + nbytes: int): + Future[int] {.base, async.} = doAssert(false, "not implemented!") method readUntil*(s: LPStream, - pbytes: pointer, nbytes: int, - sep: seq[byte]): Future[int] + pbytes: pointer, + nbytes: int, + sep: seq[byte]): + Future[int] {.base, async.} = doAssert(false, "not implemented!") diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 7b492c85a..d6715d0fe 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -285,7 +285,7 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = try: await s.upgradeIncoming(conn) # perform upgrade on incoming connection except CatchableError as exc: - trace "exception occurred in Switch.start", exc = exc.msg + trace "Exception occurred in Switch.start", exc = exc.msg finally: if not isNil(conn) and not conn.closed: await conn.close() diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 64017b708..8d5a2c4cc 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -62,7 +62,7 @@ method upgrade*(t: Transport) {.base, async, gcsafe.} = method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} = ## check if transport supportes the multiaddress - # by default we skip circuit addresses to avoid + # by default we skip circuit addresses to avoid # having to repeat the check in every transport address.protocols.filterIt( it == multiCodec("p2p-circuit") ).len == 0 diff --git a/tests/testmplex.nim b/tests/testmplex.nim index d8f48bae9..424d20298 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -193,7 +193,12 @@ suite "Mplex": proc handleMplexListen(stream: Connection) {.async, gcsafe.} = defer: await stream.close() - let msg = await stream.readLp() + + try: + discard await stream.readLp() + except CatchableError: + return + # we should not reach this anyway!! check false listenJob.complete() @@ -217,12 +222,14 @@ suite "Mplex": var bigseq = newSeqOfCap[uint8](MaxMsgSize + 1) for _ in 0..