Don't allow concurrent pushdata (#444)
* handle resets properly with/without pushes/reads * add clarifying comments * pushEof should also not be concurrent * move channel reset to bufferstream this is where the action happens - lpchannel merely redefines how close is done Co-authored-by: Jacek Sieka <jacek@status.im>
This commit is contained in:
parent
c42009d56e
commit
1d16d22f5f
|
@ -80,32 +80,11 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
|
|||
return
|
||||
|
||||
s.isClosed = true
|
||||
s.closedLocal = true
|
||||
|
||||
trace "Resetting channel", s, len = s.len
|
||||
|
||||
# First, make sure any new calls to `readOnce` and `pushData` etc will fail -
|
||||
# there may already be such calls in the event queue however
|
||||
s.closedLocal = true
|
||||
s.isEof = true
|
||||
s.readBuf = StreamSeq()
|
||||
s.pushedEof = true
|
||||
|
||||
let pushing = s.pushing # s.pushing changes while iterating
|
||||
for i in 0..<pushing:
|
||||
# Make sure to drain any ongoing pushes - there's already at least one item
|
||||
# more in the queue already so any ongoing reads shouldn't interfere
|
||||
# Notably, popFirst is not fair - which reader/writer gets woken up depends
|
||||
discard await s.readQueue.popFirst()
|
||||
|
||||
if s.readQueue.len == 0 and s.reading:
|
||||
# There is an active reader - we just grabbed all pushes so we need to push
|
||||
# an EOF marker to wake it up
|
||||
try:
|
||||
s.readQueue.addLastNoWait(@[])
|
||||
except CatchableError:
|
||||
raiseAssert "We just checked the queue is empty"
|
||||
|
||||
if not s.conn.isClosed:
|
||||
if s.isOpen and not s.conn.isClosed:
|
||||
# If the connection is still active, notify the other end
|
||||
proc resetMessage() {.async.} =
|
||||
try:
|
||||
|
@ -117,7 +96,6 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
|
|||
|
||||
asyncSpawn resetMessage()
|
||||
|
||||
# This should wake up any readers by pushing an EOF marker at least
|
||||
await s.closeImpl() # noraises, nocancels
|
||||
|
||||
trace "Channel reset", s
|
||||
|
@ -133,7 +111,7 @@ method close*(s: LPChannel) {.async, gcsafe.} =
|
|||
|
||||
trace "Closing channel", s, conn = s.conn, len = s.len
|
||||
|
||||
if s.isOpen:
|
||||
if s.isOpen and not s.conn.isClosed:
|
||||
try:
|
||||
await s.conn.writeMsg(s.id, s.closeCode) # write close
|
||||
except CancelledError as exc:
|
||||
|
|
|
@ -28,7 +28,7 @@ type
|
|||
BufferStream* = ref object of Connection
|
||||
readQueue*: AsyncQueue[seq[byte]] # read queue for managing backpressure
|
||||
readBuf*: StreamSeq # overflow buffer for readOnce
|
||||
pushing*: int # number of ongoing push operations
|
||||
pushing*: bool # number of ongoing push operations
|
||||
reading*: bool # is there an ongoing read? (only allow one)
|
||||
pushedEof*: bool # eof marker has been put on readQueue
|
||||
returnedEof*: bool # 0-byte readOnce has been completed
|
||||
|
@ -63,6 +63,8 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
|
|||
##
|
||||
## `pushTo` will block if the queue is full, thus maintaining backpressure.
|
||||
##
|
||||
|
||||
doAssert(not s.pushing, "Only one concurrent push allowed")
|
||||
if s.isClosed or s.pushedEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
|
@ -71,26 +73,29 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
|
|||
|
||||
# We will block here if there is already data queued, until it has been
|
||||
# processed
|
||||
inc s.pushing
|
||||
try:
|
||||
s.pushing = true
|
||||
trace "Pushing data", s, data = data.len
|
||||
await s.readQueue.addLast(data)
|
||||
finally:
|
||||
dec s.pushing
|
||||
s.pushing = false
|
||||
|
||||
method pushEof*(s: BufferStream) {.base, async.} =
|
||||
if s.pushedEof:
|
||||
return
|
||||
|
||||
doAssert(not s.pushing, "Only one concurrent push allowed")
|
||||
|
||||
s.pushedEof = true
|
||||
|
||||
# We will block here if there is already data queued, until it has been
|
||||
# processed
|
||||
inc s.pushing
|
||||
try:
|
||||
s.pushing = true
|
||||
trace "Pushing EOF", s
|
||||
await s.readQueue.addLast(@[])
|
||||
await s.readQueue.addLast(Eof)
|
||||
finally:
|
||||
dec s.pushing
|
||||
s.pushing = false
|
||||
|
||||
method atEof*(s: BufferStream): bool =
|
||||
s.isEof and s.readBuf.len == 0
|
||||
|
@ -159,8 +164,36 @@ method closeImpl*(s: BufferStream): Future[void] =
|
|||
## close the stream and clear the buffer
|
||||
trace "Closing BufferStream", s, len = s.len
|
||||
|
||||
if not s.pushedEof: # Potentially wake up reader
|
||||
asyncSpawn s.pushEof()
|
||||
# First, make sure any new calls to `readOnce` and `pushData` etc will fail -
|
||||
# there may already be such calls in the event queue however
|
||||
s.isEof = true
|
||||
s.readBuf = StreamSeq()
|
||||
s.pushedEof = true
|
||||
|
||||
# Essentially we need to handle the following cases
|
||||
#
|
||||
# - If a push was in progress but no reader is
|
||||
# attached we need to pop the queue
|
||||
# - If a read was in progress without without a
|
||||
# push/data we need to push the Eof marker to
|
||||
# notify the reader that the channel closed
|
||||
#
|
||||
# In all other cases, there should be a data to complete
|
||||
# a read or enough room in the queue/buffer to complete a
|
||||
# push.
|
||||
#
|
||||
# State | Q Empty | Q Full
|
||||
# ------------|----------|-------
|
||||
# Reading | Push Eof | Na
|
||||
# Pushing | Na | Pop
|
||||
if not(s.reading and s.pushing):
|
||||
if s.reading:
|
||||
if s.readQueue.empty():
|
||||
# There is an active reader
|
||||
s.readQueue.addLastNoWait(Eof)
|
||||
elif s.pushing:
|
||||
if not s.readQueue.empty():
|
||||
discard s.readQueue.popFirstNoWait()
|
||||
|
||||
trace "Closed BufferStream", s
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ logScope:
|
|||
|
||||
const
|
||||
LPStreamTrackerName* = "LPStream"
|
||||
Eof* = @[]
|
||||
|
||||
type
|
||||
Direction* {.pure.} = enum
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import unittest
|
||||
import chronos, stew/byteutils
|
||||
import ../libp2p/stream/bufferstream,
|
||||
../libp2p/stream/lpstream
|
||||
../libp2p/stream/lpstream,
|
||||
../libp2p/errors
|
||||
|
||||
import ./helpers
|
||||
|
||||
|
@ -87,10 +88,12 @@ suite "BufferStream":
|
|||
let buff = newBufferStream()
|
||||
check buff.len == 0
|
||||
|
||||
let w1 = buff.pushData("Msg 1".toBytes())
|
||||
let w2 = buff.pushData("Msg 2".toBytes())
|
||||
let w3 = buff.pushData("Msg 3".toBytes())
|
||||
proc writer1() {.async.} =
|
||||
await buff.pushData("Msg 1".toBytes())
|
||||
await buff.pushData("Msg 2".toBytes())
|
||||
await buff.pushData("Msg 3".toBytes())
|
||||
|
||||
let writerFut1 = writer1()
|
||||
var data: array[5, byte]
|
||||
await buff.readExactly(addr data[0], data.len)
|
||||
|
||||
|
@ -102,13 +105,14 @@ suite "BufferStream":
|
|||
await buff.readExactly(addr data[0], data.len)
|
||||
check string.fromBytes(data) == "Msg 3"
|
||||
|
||||
for f in [w1, w2, w3]: await f
|
||||
await writerFut1
|
||||
|
||||
let w4 = buff.pushData("Msg 4".toBytes())
|
||||
let w5 = buff.pushData("Msg 5".toBytes())
|
||||
let w6 = buff.pushData("Msg 6".toBytes())
|
||||
proc writer2() {.async.} =
|
||||
await buff.pushData("Msg 4".toBytes())
|
||||
await buff.pushData("Msg 5".toBytes())
|
||||
await buff.pushData("Msg 6".toBytes())
|
||||
|
||||
await buff.close()
|
||||
let writerFut2 = writer2()
|
||||
|
||||
await buff.readExactly(addr data[0], data.len)
|
||||
check string.fromBytes(data) == "Msg 4"
|
||||
|
@ -118,27 +122,33 @@ suite "BufferStream":
|
|||
|
||||
await buff.readExactly(addr data[0], data.len)
|
||||
check string.fromBytes(data) == "Msg 6"
|
||||
for f in [w4, w5, w6]: await f
|
||||
|
||||
await buff.close()
|
||||
await writerFut2
|
||||
|
||||
asyncTest "small reads":
|
||||
let buff = newBufferStream()
|
||||
check buff.len == 0
|
||||
|
||||
var writes: seq[Future[void]]
|
||||
var str: string
|
||||
proc writer() {.async.} =
|
||||
for i in 0..<10:
|
||||
writes.add buff.pushData("123".toBytes())
|
||||
await buff.pushData("123".toBytes())
|
||||
str &= "123"
|
||||
await buff.close() # all data should still be read after close
|
||||
|
||||
var str2: string
|
||||
|
||||
proc reader() {.async.} =
|
||||
var data: array[2, byte]
|
||||
expect LPStreamEOFError:
|
||||
while true:
|
||||
let x = await buff.readOnce(addr data[0], data.len)
|
||||
str2 &= string.fromBytes(data[0..<x])
|
||||
|
||||
for f in writes: await f
|
||||
|
||||
await allFuturesThrowing(
|
||||
allFinished(reader(), writer()))
|
||||
check str == str2
|
||||
await buff.close()
|
||||
|
||||
|
@ -196,9 +206,11 @@ suite "BufferStream":
|
|||
fut = stream.pushData(toBytes("hello"))
|
||||
fut2 = stream.pushData(toBytes("again"))
|
||||
await stream.close()
|
||||
expect AsyncTimeoutError:
|
||||
await wait(fut, 100.milliseconds)
|
||||
await wait(fut2, 100.milliseconds)
|
||||
|
||||
# Both writes should be completed on close (technically, the should maybe
|
||||
# be cancelled, at least the second one...
|
||||
check await fut.withTimeout(100.milliseconds)
|
||||
check await fut2.withTimeout(100.milliseconds)
|
||||
|
||||
await stream.close()
|
||||
|
||||
|
@ -211,3 +223,14 @@ suite "BufferStream":
|
|||
|
||||
expect LPStreamEOFError:
|
||||
await stream.pushData("123".toBytes())
|
||||
|
||||
asyncTest "no concurrent pushes":
|
||||
var stream = newBufferStream()
|
||||
await stream.pushData("123".toBytes())
|
||||
let push = stream.pushData("123".toBytes())
|
||||
|
||||
expect AssertionError:
|
||||
await stream.pushData("123".toBytes())
|
||||
|
||||
await stream.closeWithEOF()
|
||||
await push
|
||||
|
|
|
@ -215,16 +215,17 @@ suite "Mplex":
|
|||
conn = newBufferStream(writeHandler)
|
||||
chann = LPChannel.init(1, conn, true)
|
||||
|
||||
let futs = @[
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
]
|
||||
proc pushes() {.async.} = # pushes don't hang on reset
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
|
||||
let push = pushes()
|
||||
await chann.reset()
|
||||
check await allFutures(futs).withTimeout(100.millis)
|
||||
check await allFutures(push).withTimeout(100.millis)
|
||||
await conn.close()
|
||||
|
||||
asyncTest "reset should complete both read and push":
|
||||
|
@ -249,23 +250,22 @@ suite "Mplex":
|
|||
chann = LPChannel.init(1, conn, true)
|
||||
|
||||
var data = newSeq[byte](1)
|
||||
let futs = [
|
||||
chann.readExactly(addr data[0], 1),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
chann.pushData(@[0'u8]),
|
||||
]
|
||||
let read = chann.readExactly(addr data[0], 1)
|
||||
proc pushes() {.async.} =
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
|
||||
await chann.reset()
|
||||
check await allFutures(futs).withTimeout(100.millis)
|
||||
await futs[0]
|
||||
check await allFutures(read, pushes()).withTimeout(100.millis)
|
||||
await conn.close()
|
||||
|
||||
asyncTest "reset should complete both read and push with cancel":
|
||||
|
@ -300,6 +300,59 @@ suite "Mplex":
|
|||
check await allFutures(rfut, rfut2, wfut, wfut2).withTimeout(100.millis)
|
||||
await conn.close()
|
||||
|
||||
asyncTest "reset should complete ongoing push without reader":
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newBufferStream(writeHandler)
|
||||
chann = LPChannel.init(1, conn, true)
|
||||
|
||||
await chann.pushData(@[0'u8])
|
||||
let push1 = chann.pushData(@[0'u8])
|
||||
await chann.reset()
|
||||
check await allFutures(push1).withTimeout(100.millis)
|
||||
await conn.close()
|
||||
|
||||
asyncTest "reset should complete ongoing read without a push":
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newBufferStream(writeHandler)
|
||||
chann = LPChannel.init(1, conn, true)
|
||||
|
||||
var data = newSeq[byte](1)
|
||||
let rfut = chann.readExactly(addr data[0], 1)
|
||||
await chann.reset()
|
||||
check await allFutures(rfut).withTimeout(100.millis)
|
||||
await conn.close()
|
||||
|
||||
asyncTest "reset should allow all reads and pushes to complete":
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newBufferStream(writeHandler)
|
||||
chann = LPChannel.init(1, conn, true)
|
||||
|
||||
var data = newSeq[byte](1)
|
||||
proc writer() {.async.} =
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
await chann.pushData(@[0'u8])
|
||||
|
||||
proc reader() {.async.} =
|
||||
await chann.readExactly(addr data[0], 1)
|
||||
await chann.readExactly(addr data[0], 1)
|
||||
await chann.readExactly(addr data[0], 1)
|
||||
|
||||
let rw = @[writer(), reader()]
|
||||
|
||||
await chann.close()
|
||||
check await chann.reset() # this would hang
|
||||
.withTimeout(100.millis)
|
||||
|
||||
check await allFuturesThrowing(
|
||||
allFinished(rw))
|
||||
.withTimeout(100.millis)
|
||||
|
||||
await conn.close()
|
||||
|
||||
asyncTest "channel should fail writing":
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
|
|
Loading…
Reference in New Issue