diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 891e88fd1..b15045e19 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -35,25 +35,26 @@ import chronos import ../stream/lpstream type - WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} # TODO: figure out how to make this generic to avoid casts + # TODO: figure out how to make this generic to avoid casts + WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} BufferStream* = ref object of LPStream maxSize*: int # buffer's max size in bytes readBuf: Deque[byte] # a deque is based on a ring buffer - readReqs: Deque[Future[int]] # use dequeue to fire reads in order + readReqs: Deque[Future[void]] # use dequeue to fire reads in order dataReadEvent: AsyncEvent writeHandler*: WriteHandler -proc requestReadBytes(s: BufferStream): Future[int] = +proc requestReadBytes(s: BufferStream): Future[void] = ## create a future that will complete when more ## data becomes available in the read buffer - result = newFuture[int]() + result = newFuture[void]() s.readReqs.addLast(result) proc initBufferStream*(s: BufferStream, handler: WriteHandler, size: int = 1024) = s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size) s.readBuf = initDeque[byte](s.maxSize) - s.readReqs = initDeque[Future[int]]() + s.readReqs = initDeque[Future[void]]() s.dataReadEvent = newAsyncEvent() s.writeHandler = handler @@ -90,7 +91,7 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} = # resolve the next queued read request if s.readReqs.len > 0: - s.readReqs.popFirst().complete(index + 1) + s.readReqs.popFirst().complete() if index >= data.len: break @@ -111,7 +112,7 @@ method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} = inc(index) if index < size: - discard await s.requestReadBytes() + await s.requestReadBytes() method readExactly*(s: BufferStream, pbytes: pointer, @@ -174,7 +175,7 @@ method readOnce*(s: BufferStream, ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## internal buffer, otherwise it will wait until some bytes will be received. if s.readBuf.len == 0: - discard await s.requestReadBytes() + await s.requestReadBytes() var len = if nbytes > s.readBuf.len: s.readBuf.len else: nbytes await s.readExactly(pbytes, len) diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 0f9a12559..09f1baa64 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -60,9 +60,10 @@ method close*(s: ChronosStream) {.async, gcsafe.} = await s.reader.closeWait() await s.writer.finish() - if not s.writer.closed: await s.writer.closeWait() - await s.client.closeWait() + if not s.client.closed: + await s.client.closeWait() + s.closed = true \ No newline at end of file diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 56de312fe..5be050f39 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -1,4 +1,4 @@ -import unittest, sequtils, sugar +import unittest, sequtils, sugar, strformat import chronos, nimcrypto/utils import ../libp2p/connection, ../libp2p/stream/lpstream, @@ -131,6 +131,49 @@ suite "Mplex": check: waitFor(testNewStream()) == true + test "e2e - multiple streams": + proc testNewStream(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53382") + + var count = 0 + var completionFut: Future[void] = newFuture[void]() + proc connHandler(conn: Connection) {.async, gcsafe.} = + proc handleMplexListen(stream: Connection) {.async, gcsafe.} = + let msg = await stream.readLp() + check cast[string](msg) == &"Hello from stream {count}!" + count.inc + await stream.close() + if count == 11: + completionFut.complete() + + let mplexListen = newMplex(conn) + mplexListen.streamHandler = handleMplexListen + asyncCheck mplexListen.handle() + + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) + + let mplexDial = newMplex(conn) + asyncCheck mplexDial.handle() + + for i in 0..10: + let stream = await mplexDial.newStream() + await stream.writeLp(&"Hello from stream {i}!") + + await completionFut + # closing the connection doesn't transfer all the data + # this seems to be a bug in chronos + # await conn.close() + check count == 11 + + result = true + + check: + waitFor(testNewStream()) == true + test "half closed - channel should close for write": proc testClosedForWrite(): Future[void] {.async.} = let chann = newChannel(1, newConnection(new LPStream), true)