mirror of https://github.com/vacp2p/nim-libp2p.git
feat: half closed channels
This commit is contained in:
parent
6058a3fc69
commit
9b485b3082
|
@ -8,27 +8,123 @@
|
|||
## those terms.
|
||||
|
||||
import chronos
|
||||
import ../../stream/bufferstream
|
||||
import types
|
||||
import ../../stream/bufferstream,
|
||||
../../stream/lpstream,
|
||||
types, coder, ../../connection
|
||||
|
||||
type
|
||||
Channel* = ref object of BufferStream
|
||||
id*: int
|
||||
conn*: Connection
|
||||
initiator*: bool
|
||||
isReset*: bool
|
||||
closedLocal*: bool
|
||||
closedRemote*: bool
|
||||
handlerFuture*: Future[void]
|
||||
msgCode*: MessageType
|
||||
closeCode*: MessageType
|
||||
resetCode*: MessageType
|
||||
|
||||
proc newChannel*(id: int,
|
||||
conn: Connection,
|
||||
initiator: bool,
|
||||
handler: WriteHandler,
|
||||
size: int = MaxMsgSize): Channel =
|
||||
new result
|
||||
result.id = id
|
||||
result.conn = conn
|
||||
result.initiator = initiator
|
||||
result.initBufferStream(handler, size)
|
||||
result.msgCode = if initiator: MessageType.MsgOut else: MessageType.MsgIn
|
||||
result.closeCode = if initiator: MessageType.CloseOut else: MessageType.CloseIn
|
||||
result.resetCode = if initiator: MessageType.ResetOut else: MessageType.ResetIn
|
||||
|
||||
proc closed*(s: Channel): bool = s.closedLocal and s.closedRemote
|
||||
proc close*(s: Channel) {.async.} = discard
|
||||
proc reset*(s: Channel) {.async.} = discard
|
||||
let chan = result
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
||||
await conn.writeHeader(id, chan.msgCode, data.len) # write header
|
||||
await conn.write(data)
|
||||
|
||||
result.initBufferStream(writeHandler, size)
|
||||
|
||||
proc closeMessage(s: Channel) {.async, gcsafe.} =
|
||||
await s.conn.writeHeader(s.id, s.closeCode, 0) # write header
|
||||
|
||||
proc closed*(s: Channel): bool =
|
||||
s.closedLocal
|
||||
|
||||
proc closeRemote*(s: Channel) {.async.} =
|
||||
s.closedRemote = true
|
||||
|
||||
method close*(s: Channel) {.async, gcsafe.} =
|
||||
s.closedLocal = true
|
||||
await s.closeMessage()
|
||||
|
||||
proc resetMessage(s: Channel) {.async, gcsafe.} =
|
||||
await s.conn.writeHeader(s.id, s.resetCode, 0) # write header
|
||||
|
||||
proc remoteReset*(s: Channel) {.async, gcsafe.} =
|
||||
await allFutures(s.close(), s.closeRemote())
|
||||
s.isReset = true
|
||||
|
||||
proc reset*(s: Channel) {.async.} =
|
||||
await allFutures(s.resetMessage(), s.remoteReset())
|
||||
|
||||
proc isReadEof(s: Channel): bool =
|
||||
bool((s.closedRemote or s.closedLocal) and s.len() <= 0)
|
||||
|
||||
method pushTo*(s: Channel, data: seq[byte]): Future[void] {.gcsafe.} =
|
||||
if s.closedRemote:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall pushTo(BufferStream(s), data)
|
||||
|
||||
method read*(s: Channel, n = -1): Future[seq[byte]] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall read(BufferStream(s), n)
|
||||
|
||||
method readExactly*(s: Channel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readExactly(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method readLine*(s: Channel,
|
||||
limit = 0,
|
||||
sep = "\r\n"):
|
||||
Future[string] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readLine(BufferStream(s), limit, sep)
|
||||
|
||||
method readOnce*(s: Channel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method readUntil*(s: Channel,
|
||||
pbytes: pointer, nbytes: int,
|
||||
sep: seq[byte]):
|
||||
Future[int] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method write*(s: Channel,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.gcsafe.} =
|
||||
if s.closedLocal:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall write(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method write*(s: Channel, msg: string, msglen = -1) {.async, gcsafe.} =
|
||||
if s.closedLocal:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall write(BufferStream(s), msg, msglen)
|
||||
|
||||
method write*(s: Channel, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
||||
if s.closedLocal:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall write(BufferStream(s), msg, msglen)
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import chronos
|
||||
import ../../connection, ../../varint,
|
||||
../../vbuffer, mplex, types,
|
||||
../../vbuffer, types,
|
||||
../../stream/lpstream
|
||||
|
||||
proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe.} =
|
||||
|
|
|
@ -26,10 +26,6 @@ type
|
|||
proc newMplexUnknownMsgError(): ref MplexUnknownMsgError =
|
||||
result = newException(MplexUnknownMsgError, "Unknown mplex message type")
|
||||
|
||||
##########################################
|
||||
## Mplex
|
||||
##########################################
|
||||
|
||||
proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] =
|
||||
if initiator:
|
||||
result = m.remote
|
||||
|
@ -42,12 +38,7 @@ proc newStreamInternal*(m: Mplex,
|
|||
Future[Channel] {.async, gcsafe.} =
|
||||
## create new channel/stream
|
||||
let id = if initiator: m.currentId.inc(); m.currentId else: chanId
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
||||
let msgType = if initiator: MessageType.MsgOut else: MessageType.MsgIn
|
||||
await m.connection.writeHeader(id, msgType, data.len) # write header
|
||||
await m.connection.write(data) # write data
|
||||
|
||||
result = newChannel(id, initiator, writeHandler)
|
||||
result = newChannel(id, m.connection, initiator)
|
||||
m.getChannelList(initiator)[id] = result
|
||||
|
||||
proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} =
|
||||
|
@ -68,7 +59,7 @@ proc handle*(m: Mplex): Future[void] {.async, gcsafe.} =
|
|||
await channel.pushTo(msg)
|
||||
of MessageType.CloseIn, MessageType.CloseOut:
|
||||
let channel = m.getChannelList(initiator)[id.int]
|
||||
await channel.close()
|
||||
await channel.closeRemote()
|
||||
of MessageType.ResetIn, MessageType.ResetOut:
|
||||
let channel = m.getChannelList(initiator)[id.int]
|
||||
await channel.reset()
|
||||
|
|
|
@ -13,6 +13,7 @@ import ../../connection
|
|||
const MaxMsgSize* = 1 shl 20 # 1mb
|
||||
const MaxChannels* = 1000
|
||||
const MplexCodec* = "/mplex/6.7.0"
|
||||
const MaxReadWriteTime* = 5.seconds
|
||||
|
||||
type
|
||||
MplexUnknownMsgError* = object of CatchableError
|
||||
|
|
|
@ -21,6 +21,7 @@ type
|
|||
par*: ref Exception
|
||||
LPStreamWriteError* = object of LPStreamError
|
||||
par*: ref Exception
|
||||
LPStreamClosedError* = object of LPStreamError
|
||||
|
||||
proc newLPStreamReadError*(p: ref Exception): ref Exception {.inline.} =
|
||||
var w = newException(LPStreamReadError, "Read stream failed")
|
||||
|
@ -43,6 +44,9 @@ proc newLPStreamLimitError*(): ref Exception {.inline.} =
|
|||
proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
|
||||
result = newException(LPStreamIncorrectError, m)
|
||||
|
||||
proc newLPStreamClosedError*(): ref Exception {.inline.} =
|
||||
result = newException(LPStreamClosedError, "Stream closed!")
|
||||
|
||||
method read*(s: LPStream, n = -1): Future[seq[byte]]
|
||||
{.base, async, gcsafe.} =
|
||||
assert(false, "not implemented!")
|
||||
|
|
|
@ -2,6 +2,7 @@ import unittest, sequtils, sugar
|
|||
import chronos, nimcrypto/utils
|
||||
import ../libp2p/connection,
|
||||
../libp2p/stream/lpstream,
|
||||
../libp2p/stream/bufferstream,
|
||||
../libp2p/tcptransport,
|
||||
../libp2p/transport,
|
||||
../libp2p/multiaddress,
|
||||
|
@ -92,7 +93,6 @@ suite "Mplex":
|
|||
let dialFut = mplexDial.handle()
|
||||
let stream = await mplexDial.newStream()
|
||||
check cast[string](await stream.readLp()) == "Hello from stream!"
|
||||
|
||||
await conn.close()
|
||||
await dialFut
|
||||
result = true
|
||||
|
@ -130,3 +130,51 @@ suite "Mplex":
|
|||
|
||||
check:
|
||||
waitFor(testNewStream()) == true
|
||||
|
||||
test "half closed - channel should close for write":
|
||||
proc testClosedForWrite(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
await chann.close()
|
||||
await chann.write("Hello")
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testClosedForWrite())
|
||||
|
||||
test "half closed - channel should close for read":
|
||||
proc testClosedForRead(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
await chann.closeRemote()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testClosedForRead())
|
||||
|
||||
test "half closed - channel should close for read after eof":
|
||||
proc testClosedForRead(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
|
||||
await chann.pushTo(cast[seq[byte]](toSeq("Hello!".items)))
|
||||
await chann.close()
|
||||
let msg = await chann.read()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testClosedForRead())
|
||||
|
||||
test "reset - channel should fail reading":
|
||||
proc testResetRead(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
await chann.reset()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testResetRead())
|
||||
|
||||
test "reset - channel should fail writing":
|
||||
proc testResetWrite(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
await chann.reset()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testResetWrite())
|
||||
|
|
Loading…
Reference in New Issue