Don't allow concurrent pushdata (#444)

* handle resets properly with/without pushes/reads

* add clarifying comments

* pushEof should also not be concurrent

* move channel reset to bufferstream

this is where the action happens - lpchannel merely redefines how close
is done

Co-authored-by: Jacek Sieka <jacek@status.im>
This commit is contained in:
Dmitriy Ryajov 2020-11-23 09:07:11 -06:00 committed by GitHub
parent c42009d56e
commit 1d16d22f5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 170 additions and 82 deletions

View File

@ -80,32 +80,11 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
return return
s.isClosed = true s.isClosed = true
s.closedLocal = true
trace "Resetting channel", s, len = s.len trace "Resetting channel", s, len = s.len
# First, make sure any new calls to `readOnce` and `pushData` etc will fail - if s.isOpen and not s.conn.isClosed:
# there may already be such calls in the event queue however
s.closedLocal = true
s.isEof = true
s.readBuf = StreamSeq()
s.pushedEof = true
let pushing = s.pushing # s.pushing changes while iterating
for i in 0..<pushing:
# Make sure to drain any ongoing pushes - there's already at least one item
# more in the queue already so any ongoing reads shouldn't interfere
# Notably, popFirst is not fair - which reader/writer gets woken up depends
discard await s.readQueue.popFirst()
if s.readQueue.len == 0 and s.reading:
# There is an active reader - we just grabbed all pushes so we need to push
# an EOF marker to wake it up
try:
s.readQueue.addLastNoWait(@[])
except CatchableError:
raiseAssert "We just checked the queue is empty"
if not s.conn.isClosed:
# If the connection is still active, notify the other end # If the connection is still active, notify the other end
proc resetMessage() {.async.} = proc resetMessage() {.async.} =
try: try:
@ -117,7 +96,6 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
asyncSpawn resetMessage() asyncSpawn resetMessage()
# This should wake up any readers by pushing an EOF marker at least
await s.closeImpl() # noraises, nocancels await s.closeImpl() # noraises, nocancels
trace "Channel reset", s trace "Channel reset", s
@ -133,7 +111,7 @@ method close*(s: LPChannel) {.async, gcsafe.} =
trace "Closing channel", s, conn = s.conn, len = s.len trace "Closing channel", s, conn = s.conn, len = s.len
if s.isOpen: if s.isOpen and not s.conn.isClosed:
try: try:
await s.conn.writeMsg(s.id, s.closeCode) # write close await s.conn.writeMsg(s.id, s.closeCode) # write close
except CancelledError as exc: except CancelledError as exc:

View File

@ -28,7 +28,7 @@ type
BufferStream* = ref object of Connection BufferStream* = ref object of Connection
readQueue*: AsyncQueue[seq[byte]] # read queue for managing backpressure readQueue*: AsyncQueue[seq[byte]] # read queue for managing backpressure
readBuf*: StreamSeq # overflow buffer for readOnce readBuf*: StreamSeq # overflow buffer for readOnce
pushing*: int # number of ongoing push operations pushing*: bool # number of ongoing push operations
reading*: bool # is there an ongoing read? (only allow one) reading*: bool # is there an ongoing read? (only allow one)
pushedEof*: bool # eof marker has been put on readQueue pushedEof*: bool # eof marker has been put on readQueue
returnedEof*: bool # 0-byte readOnce has been completed returnedEof*: bool # 0-byte readOnce has been completed
@ -63,6 +63,8 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
## ##
## `pushTo` will block if the queue is full, thus maintaining backpressure. ## `pushTo` will block if the queue is full, thus maintaining backpressure.
## ##
doAssert(not s.pushing, "Only one concurrent push allowed")
if s.isClosed or s.pushedEof: if s.isClosed or s.pushedEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
@ -71,26 +73,29 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
# We will block here if there is already data queued, until it has been # We will block here if there is already data queued, until it has been
# processed # processed
inc s.pushing
try: try:
s.pushing = true
trace "Pushing data", s, data = data.len trace "Pushing data", s, data = data.len
await s.readQueue.addLast(data) await s.readQueue.addLast(data)
finally: finally:
dec s.pushing s.pushing = false
method pushEof*(s: BufferStream) {.base, async.} = method pushEof*(s: BufferStream) {.base, async.} =
if s.pushedEof: if s.pushedEof:
return return
doAssert(not s.pushing, "Only one concurrent push allowed")
s.pushedEof = true s.pushedEof = true
# We will block here if there is already data queued, until it has been # We will block here if there is already data queued, until it has been
# processed # processed
inc s.pushing
try: try:
s.pushing = true
trace "Pushing EOF", s trace "Pushing EOF", s
await s.readQueue.addLast(@[]) await s.readQueue.addLast(Eof)
finally: finally:
dec s.pushing s.pushing = false
method atEof*(s: BufferStream): bool = method atEof*(s: BufferStream): bool =
s.isEof and s.readBuf.len == 0 s.isEof and s.readBuf.len == 0
@ -159,8 +164,36 @@ method closeImpl*(s: BufferStream): Future[void] =
## close the stream and clear the buffer ## close the stream and clear the buffer
trace "Closing BufferStream", s, len = s.len trace "Closing BufferStream", s, len = s.len
if not s.pushedEof: # Potentially wake up reader # First, make sure any new calls to `readOnce` and `pushData` etc will fail -
asyncSpawn s.pushEof() # there may already be such calls in the event queue however
s.isEof = true
s.readBuf = StreamSeq()
s.pushedEof = true
# Essentially we need to handle the following cases
#
# - If a push was in progress but no reader is
# attached we need to pop the queue
# - If a read was in progress without without a
# push/data we need to push the Eof marker to
# notify the reader that the channel closed
#
# In all other cases, there should be a data to complete
# a read or enough room in the queue/buffer to complete a
# push.
#
# State | Q Empty | Q Full
# ------------|----------|-------
# Reading | Push Eof | Na
# Pushing | Na | Pop
if not(s.reading and s.pushing):
if s.reading:
if s.readQueue.empty():
# There is an active reader
s.readQueue.addLastNoWait(Eof)
elif s.pushing:
if not s.readQueue.empty():
discard s.readQueue.popFirstNoWait()
trace "Closed BufferStream", s trace "Closed BufferStream", s

View File

@ -25,6 +25,7 @@ logScope:
const const
LPStreamTrackerName* = "LPStream" LPStreamTrackerName* = "LPStream"
Eof* = @[]
type type
Direction* {.pure.} = enum Direction* {.pure.} = enum

View File

@ -1,7 +1,8 @@
import unittest import unittest
import chronos, stew/byteutils import chronos, stew/byteutils
import ../libp2p/stream/bufferstream, import ../libp2p/stream/bufferstream,
../libp2p/stream/lpstream ../libp2p/stream/lpstream,
../libp2p/errors
import ./helpers import ./helpers
@ -87,10 +88,12 @@ suite "BufferStream":
let buff = newBufferStream() let buff = newBufferStream()
check buff.len == 0 check buff.len == 0
let w1 = buff.pushData("Msg 1".toBytes()) proc writer1() {.async.} =
let w2 = buff.pushData("Msg 2".toBytes()) await buff.pushData("Msg 1".toBytes())
let w3 = buff.pushData("Msg 3".toBytes()) await buff.pushData("Msg 2".toBytes())
await buff.pushData("Msg 3".toBytes())
let writerFut1 = writer1()
var data: array[5, byte] var data: array[5, byte]
await buff.readExactly(addr data[0], data.len) await buff.readExactly(addr data[0], data.len)
@ -102,13 +105,14 @@ suite "BufferStream":
await buff.readExactly(addr data[0], data.len) await buff.readExactly(addr data[0], data.len)
check string.fromBytes(data) == "Msg 3" check string.fromBytes(data) == "Msg 3"
for f in [w1, w2, w3]: await f await writerFut1
let w4 = buff.pushData("Msg 4".toBytes()) proc writer2() {.async.} =
let w5 = buff.pushData("Msg 5".toBytes()) await buff.pushData("Msg 4".toBytes())
let w6 = buff.pushData("Msg 6".toBytes()) await buff.pushData("Msg 5".toBytes())
await buff.pushData("Msg 6".toBytes())
await buff.close() let writerFut2 = writer2()
await buff.readExactly(addr data[0], data.len) await buff.readExactly(addr data[0], data.len)
check string.fromBytes(data) == "Msg 4" check string.fromBytes(data) == "Msg 4"
@ -118,27 +122,33 @@ suite "BufferStream":
await buff.readExactly(addr data[0], data.len) await buff.readExactly(addr data[0], data.len)
check string.fromBytes(data) == "Msg 6" check string.fromBytes(data) == "Msg 6"
for f in [w4, w5, w6]: await f
await buff.close()
await writerFut2
asyncTest "small reads": asyncTest "small reads":
let buff = newBufferStream() let buff = newBufferStream()
check buff.len == 0 check buff.len == 0
var writes: seq[Future[void]]
var str: string var str: string
proc writer() {.async.} =
for i in 0..<10: for i in 0..<10:
writes.add buff.pushData("123".toBytes()) await buff.pushData("123".toBytes())
str &= "123" str &= "123"
await buff.close() # all data should still be read after close await buff.close() # all data should still be read after close
var str2: string var str2: string
proc reader() {.async.} =
var data: array[2, byte] var data: array[2, byte]
expect LPStreamEOFError: expect LPStreamEOFError:
while true: while true:
let x = await buff.readOnce(addr data[0], data.len) let x = await buff.readOnce(addr data[0], data.len)
str2 &= string.fromBytes(data[0..<x]) str2 &= string.fromBytes(data[0..<x])
for f in writes: await f
await allFuturesThrowing(
allFinished(reader(), writer()))
check str == str2 check str == str2
await buff.close() await buff.close()
@ -196,9 +206,11 @@ suite "BufferStream":
fut = stream.pushData(toBytes("hello")) fut = stream.pushData(toBytes("hello"))
fut2 = stream.pushData(toBytes("again")) fut2 = stream.pushData(toBytes("again"))
await stream.close() await stream.close()
expect AsyncTimeoutError:
await wait(fut, 100.milliseconds) # Both writes should be completed on close (technically, the should maybe
await wait(fut2, 100.milliseconds) # be cancelled, at least the second one...
check await fut.withTimeout(100.milliseconds)
check await fut2.withTimeout(100.milliseconds)
await stream.close() await stream.close()
@ -211,3 +223,14 @@ suite "BufferStream":
expect LPStreamEOFError: expect LPStreamEOFError:
await stream.pushData("123".toBytes()) await stream.pushData("123".toBytes())
asyncTest "no concurrent pushes":
var stream = newBufferStream()
await stream.pushData("123".toBytes())
let push = stream.pushData("123".toBytes())
expect AssertionError:
await stream.pushData("123".toBytes())
await stream.closeWithEOF()
await push

View File

@ -215,16 +215,17 @@ suite "Mplex":
conn = newBufferStream(writeHandler) conn = newBufferStream(writeHandler)
chann = LPChannel.init(1, conn, true) chann = LPChannel.init(1, conn, true)
let futs = @[ proc pushes() {.async.} = # pushes don't hang on reset
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
]
let push = pushes()
await chann.reset() await chann.reset()
check await allFutures(futs).withTimeout(100.millis) check await allFutures(push).withTimeout(100.millis)
await conn.close() await conn.close()
asyncTest "reset should complete both read and push": asyncTest "reset should complete both read and push":
@ -249,23 +250,22 @@ suite "Mplex":
chann = LPChannel.init(1, conn, true) chann = LPChannel.init(1, conn, true)
var data = newSeq[byte](1) var data = newSeq[byte](1)
let futs = [ let read = chann.readExactly(addr data[0], 1)
chann.readExactly(addr data[0], 1), proc pushes() {.async.} =
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
chann.pushData(@[0'u8]), await chann.pushData(@[0'u8])
]
await chann.reset() await chann.reset()
check await allFutures(futs).withTimeout(100.millis) check await allFutures(read, pushes()).withTimeout(100.millis)
await futs[0]
await conn.close() await conn.close()
asyncTest "reset should complete both read and push with cancel": asyncTest "reset should complete both read and push with cancel":
@ -300,6 +300,59 @@ suite "Mplex":
check await allFutures(rfut, rfut2, wfut, wfut2).withTimeout(100.millis) check await allFutures(rfut, rfut2, wfut, wfut2).withTimeout(100.millis)
await conn.close() await conn.close()
asyncTest "reset should complete ongoing push without reader":
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newBufferStream(writeHandler)
chann = LPChannel.init(1, conn, true)
await chann.pushData(@[0'u8])
let push1 = chann.pushData(@[0'u8])
await chann.reset()
check await allFutures(push1).withTimeout(100.millis)
await conn.close()
asyncTest "reset should complete ongoing read without a push":
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newBufferStream(writeHandler)
chann = LPChannel.init(1, conn, true)
var data = newSeq[byte](1)
let rfut = chann.readExactly(addr data[0], 1)
await chann.reset()
check await allFutures(rfut).withTimeout(100.millis)
await conn.close()
asyncTest "reset should allow all reads and pushes to complete":
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newBufferStream(writeHandler)
chann = LPChannel.init(1, conn, true)
var data = newSeq[byte](1)
proc writer() {.async.} =
await chann.pushData(@[0'u8])
await chann.pushData(@[0'u8])
await chann.pushData(@[0'u8])
proc reader() {.async.} =
await chann.readExactly(addr data[0], 1)
await chann.readExactly(addr data[0], 1)
await chann.readExactly(addr data[0], 1)
let rw = @[writer(), reader()]
await chann.close()
check await chann.reset() # this would hang
.withTimeout(100.millis)
check await allFuturesThrowing(
allFinished(rw))
.withTimeout(100.millis)
await conn.close()
asyncTest "channel should fail writing": asyncTest "channel should fail writing":
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let let