Mirror of https://github.com/codex-storage/nim-libp2p.git (synced 2025-01-12 20:14:09 +00:00)
Don't allow concurrent pushdata (#444)
* handle resets properly with/without pushes/reads
* add clarifying comments
* pushEof should also not be concurrent
* move channel reset to bufferstream
  this is where the action happens - lpchannel merely redefines how close is done

Co-authored-by: Jacek Sieka <jacek@status.im>
parent c42009d56e
commit 1d16d22f5f
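
The practical upshot for callers: pushes into a BufferStream (and hence into an mplex LPChannel) must now be serialized instead of being issued as several outstanding futures. A minimal sketch of the pattern the updated tests below use; the import paths and the waitFor wrapper are assumptions, the calls themselves mirror the diff:

  import chronos, stew/byteutils
  import libp2p/stream/bufferstream, libp2p/stream/lpstream

  proc demo() {.async.} =
    let buff = newBufferStream()

    # one writer task awaits each push before starting the next -
    # issuing concurrent pushData futures now trips an assertion
    proc writer() {.async.} =
      await buff.pushData("Msg 1".toBytes())
      await buff.pushData("Msg 2".toBytes())
      await buff.pushData("Msg 3".toBytes())

    let writerFut = writer()

    var data: array[5, byte]
    await buff.readExactly(addr data[0], data.len) # "Msg 1"
    await buff.readExactly(addr data[0], data.len) # "Msg 2"
    await buff.readExactly(addr data[0], data.len) # "Msg 3"

    await writerFut
    await buff.close()

  waitFor demo()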
lpchannel.nim

@@ -80,32 +80,11 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
    return

  s.isClosed = true
+  s.closedLocal = true

  trace "Resetting channel", s, len = s.len

-  # First, make sure any new calls to `readOnce` and `pushData` etc will fail -
-  # there may already be such calls in the event queue however
-  s.closedLocal = true
-  s.isEof = true
-  s.readBuf = StreamSeq()
-  s.pushedEof = true
-
-  let pushing = s.pushing # s.pushing changes while iterating
-  for i in 0..<pushing:
-    # Make sure to drain any ongoing pushes - there's already at least one item
-    # more in the queue already so any ongoing reads shouldn't interfere
-    # Notably, popFirst is not fair - which reader/writer gets woken up depends
-    discard await s.readQueue.popFirst()
-
-  if s.readQueue.len == 0 and s.reading:
-    # There is an active reader - we just grabbed all pushes so we need to push
-    # an EOF marker to wake it up
-    try:
-      s.readQueue.addLastNoWait(@[])
-    except CatchableError:
-      raiseAssert "We just checked the queue is empty"
-
-  if not s.conn.isClosed:
+  if s.isOpen and not s.conn.isClosed:
    # If the connection is still active, notify the other end
    proc resetMessage() {.async.} =
      try:

@@ -117,7 +96,6 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =

    asyncSpawn resetMessage()

-  # This should wake up any readers by pushing an EOF marker at least
  await s.closeImpl() # noraises, nocancels

  trace "Channel reset", s

@@ -133,7 +111,7 @@ method close*(s: LPChannel) {.async, gcsafe.} =

  trace "Closing channel", s, conn = s.conn, len = s.len

-  if s.isOpen:
+  if s.isOpen and not s.conn.isClosed:
    try:
      await s.conn.writeMsg(s.id, s.closeCode) # write close
    except CancelledError as exc:
bufferstream.nim

@@ -28,7 +28,7 @@ type
  BufferStream* = ref object of Connection
    readQueue*: AsyncQueue[seq[byte]] # read queue for managing backpressure
    readBuf*: StreamSeq # overflow buffer for readOnce
-    pushing*: int # number of ongoing push operations
+    pushing*: bool # number of ongoing push operations
    reading*: bool # is there an ongoing read? (only allow one)
    pushedEof*: bool # eof marker has been put on readQueue
    returnedEof*: bool # 0-byte readOnce has been completed

@@ -63,6 +63,8 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
  ##
  ## `pushTo` will block if the queue is full, thus maintaining backpressure.
  ##

+  doAssert(not s.pushing, "Only one concurrent push allowed")
+
  if s.isClosed or s.pushedEof:
    raise newLPStreamEOFError()

@@ -71,26 +73,29 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =

  # We will block here if there is already data queued, until it has been
  # processed
-  inc s.pushing
  try:
+    s.pushing = true
    trace "Pushing data", s, data = data.len
    await s.readQueue.addLast(data)
  finally:
-    dec s.pushing
+    s.pushing = false

method pushEof*(s: BufferStream) {.base, async.} =
  if s.pushedEof:
    return

+  doAssert(not s.pushing, "Only one concurrent push allowed")
+
  s.pushedEof = true

  # We will block here if there is already data queued, until it has been
  # processed
-  inc s.pushing
  try:
+    s.pushing = true
    trace "Pushing EOF", s
-    await s.readQueue.addLast(@[])
+    await s.readQueue.addLast(Eof)
  finally:
-    dec s.pushing
+    s.pushing = false

method atEof*(s: BufferStream): bool =
  s.isEof and s.readBuf.len == 0
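
To make the new invariant concrete, an illustrative fragment (to be read as an asyncTest body; it mirrors the "no concurrent pushes" test added further down in this patch): starting another push while one is still in flight now fails the assertion instead of quietly queueing behind it.

  var stream = newBufferStream()
  await stream.pushData("123".toBytes())      # completes, queue now holds one item
  let push = stream.pushData("123".toBytes()) # blocks until the first item is read

  expect AssertionError:                      # "Only one concurrent push allowed"
    await stream.pushData("123".toBytes())

  await stream.closeWithEOF()
  await push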

@@ -159,8 +164,36 @@ method closeImpl*(s: BufferStream): Future[void] =
  ## close the stream and clear the buffer
  trace "Closing BufferStream", s, len = s.len

-  if not s.pushedEof: # Potentially wake up reader
-    asyncSpawn s.pushEof()
+  # First, make sure any new calls to `readOnce` and `pushData` etc will fail -
+  # there may already be such calls in the event queue however
+  s.isEof = true
+  s.readBuf = StreamSeq()
+  s.pushedEof = true
+
+  # Essentially we need to handle the following cases
+  #
+  # - If a push was in progress but no reader is
+  #   attached we need to pop the queue
+  # - If a read was in progress without without a
+  #   push/data we need to push the Eof marker to
+  #   notify the reader that the channel closed
+  #
+  # In all other cases, there should be a data to complete
+  # a read or enough room in the queue/buffer to complete a
+  # push.
+  #
+  # State   | Q Empty  | Q Full
+  # --------|----------|-------
+  # Reading | Push Eof | Na
+  # Pushing | Na       | Pop
+  if not(s.reading and s.pushing):
+    if s.reading:
+      if s.readQueue.empty():
+        # There is an active reader
+        s.readQueue.addLastNoWait(Eof)
+    elif s.pushing:
+      if not s.readQueue.empty():
+        discard s.readQueue.popFirstNoWait()

  trace "Closed BufferStream", s
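
From the reader's side, the Eof marker queued above is what eventually terminates a read loop once the buffered data has been drained. A rough sketch of such a drain loop, assuming the readOnce/LPStreamEOFError behaviour exercised by the "small reads" test below:

  proc drain(buff: BufferStream): Future[string] {.async.} =
    var data: array[2, byte]
    try:
      while true:
        # readOnce keeps returning buffered bytes; after the Eof marker
        # has been consumed, further reads raise LPStreamEOFError
        let n = await buff.readOnce(addr data[0], data.len)
        result &= string.fromBytes(data[0..<n])
    except LPStreamEOFError:
      discard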

lpstream.nim

@@ -25,6 +25,7 @@ logScope:

const
  LPStreamTrackerName* = "LPStream"
+  Eof* = @[]

type
  Direction* {.pure.} = enum

testbufferstream.nim

@@ -1,7 +1,8 @@
import unittest
import chronos, stew/byteutils
import ../libp2p/stream/bufferstream,
-       ../libp2p/stream/lpstream
+       ../libp2p/stream/lpstream,
+       ../libp2p/errors

import ./helpers

@@ -87,10 +88,12 @@ suite "BufferStream":
    let buff = newBufferStream()
    check buff.len == 0

-    let w1 = buff.pushData("Msg 1".toBytes())
-    let w2 = buff.pushData("Msg 2".toBytes())
-    let w3 = buff.pushData("Msg 3".toBytes())
+    proc writer1() {.async.} =
+      await buff.pushData("Msg 1".toBytes())
+      await buff.pushData("Msg 2".toBytes())
+      await buff.pushData("Msg 3".toBytes())

+    let writerFut1 = writer1()
    var data: array[5, byte]
    await buff.readExactly(addr data[0], data.len)

@@ -102,13 +105,14 @@ suite "BufferStream":
    await buff.readExactly(addr data[0], data.len)
    check string.fromBytes(data) == "Msg 3"

-    for f in [w1, w2, w3]: await f
+    await writerFut1

-    let w4 = buff.pushData("Msg 4".toBytes())
-    let w5 = buff.pushData("Msg 5".toBytes())
-    let w6 = buff.pushData("Msg 6".toBytes())
+    proc writer2() {.async.} =
+      await buff.pushData("Msg 4".toBytes())
+      await buff.pushData("Msg 5".toBytes())
+      await buff.pushData("Msg 6".toBytes())

-    await buff.close()
+    let writerFut2 = writer2()

    await buff.readExactly(addr data[0], data.len)
    check string.fromBytes(data) == "Msg 4"

@@ -118,27 +122,33 @@ suite "BufferStream":

    await buff.readExactly(addr data[0], data.len)
    check string.fromBytes(data) == "Msg 6"
-    for f in [w4, w5, w6]: await f
+
+    await buff.close()
+    await writerFut2

  asyncTest "small reads":
    let buff = newBufferStream()
    check buff.len == 0

-    var writes: seq[Future[void]]
    var str: string
+    proc writer() {.async.} =
      for i in 0..<10:
-        writes.add buff.pushData("123".toBytes())
+        await buff.pushData("123".toBytes())
        str &= "123"
      await buff.close() # all data should still be read after close

    var str2: string

+    proc reader() {.async.} =
      var data: array[2, byte]
      expect LPStreamEOFError:
        while true:
          let x = await buff.readOnce(addr data[0], data.len)
          str2 &= string.fromBytes(data[0..<x])

-    for f in writes: await f
+    await allFuturesThrowing(
+      allFinished(reader(), writer()))
    check str == str2
    await buff.close()

@@ -196,9 +206,11 @@ suite "BufferStream":
      fut = stream.pushData(toBytes("hello"))
      fut2 = stream.pushData(toBytes("again"))
    await stream.close()
-    expect AsyncTimeoutError:
-      await wait(fut, 100.milliseconds)
-      await wait(fut2, 100.milliseconds)
+
+    # Both writes should be completed on close (technically, the should maybe
+    # be cancelled, at least the second one...
+    check await fut.withTimeout(100.milliseconds)
+    check await fut2.withTimeout(100.milliseconds)

    await stream.close()

@@ -211,3 +223,14 @@ suite "BufferStream":

    expect LPStreamEOFError:
      await stream.pushData("123".toBytes())
+
+  asyncTest "no concurrent pushes":
+    var stream = newBufferStream()
+    await stream.pushData("123".toBytes())
+    let push = stream.pushData("123".toBytes())
+
+    expect AssertionError:
+      await stream.pushData("123".toBytes())
+
+    await stream.closeWithEOF()
+    await push

testmplex.nim

@@ -215,16 +215,17 @@ suite "Mplex":
        conn = newBufferStream(writeHandler)
        chann = LPChannel.init(1, conn, true)

-      let futs = @[
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-      ]
+      proc pushes() {.async.} = # pushes don't hang on reset
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+
+      let push = pushes()
      await chann.reset()
-      check await allFutures(futs).withTimeout(100.millis)
+      check await allFutures(push).withTimeout(100.millis)
      await conn.close()

    asyncTest "reset should complete both read and push":

@@ -249,23 +250,22 @@ suite "Mplex":
        chann = LPChannel.init(1, conn, true)

      var data = newSeq[byte](1)
-      let futs = [
-        chann.readExactly(addr data[0], 1),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-        chann.pushData(@[0'u8]),
-      ]
+      let read = chann.readExactly(addr data[0], 1)
+      proc pushes() {.async.} =
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+
      await chann.reset()
-      check await allFutures(futs).withTimeout(100.millis)
-      await futs[0]
+      check await allFutures(read, pushes()).withTimeout(100.millis)
      await conn.close()

    asyncTest "reset should complete both read and push with cancel":

@@ -300,6 +300,59 @@ suite "Mplex":
      check await allFutures(rfut, rfut2, wfut, wfut2).withTimeout(100.millis)
      await conn.close()

+    asyncTest "reset should complete ongoing push without reader":
+      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
+      let
+        conn = newBufferStream(writeHandler)
+        chann = LPChannel.init(1, conn, true)
+
+      await chann.pushData(@[0'u8])
+      let push1 = chann.pushData(@[0'u8])
+      await chann.reset()
+      check await allFutures(push1).withTimeout(100.millis)
+      await conn.close()
+
+    asyncTest "reset should complete ongoing read without a push":
+      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
+      let
+        conn = newBufferStream(writeHandler)
+        chann = LPChannel.init(1, conn, true)
+
+      var data = newSeq[byte](1)
+      let rfut = chann.readExactly(addr data[0], 1)
+      await chann.reset()
+      check await allFutures(rfut).withTimeout(100.millis)
+      await conn.close()
+
+    asyncTest "reset should allow all reads and pushes to complete":
+      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
+      let
+        conn = newBufferStream(writeHandler)
+        chann = LPChannel.init(1, conn, true)
+
+      var data = newSeq[byte](1)
+      proc writer() {.async.} =
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+        await chann.pushData(@[0'u8])
+
+      proc reader() {.async.} =
+        await chann.readExactly(addr data[0], 1)
+        await chann.readExactly(addr data[0], 1)
+        await chann.readExactly(addr data[0], 1)
+
+      let rw = @[writer(), reader()]
+
+      await chann.close()
+      check await chann.reset() # this would hang
+        .withTimeout(100.millis)
+
+      check await allFuturesThrowing(
+        allFinished(rw))
+        .withTimeout(100.millis)
+
+      await conn.close()
+
    asyncTest "channel should fail writing":
      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
      let