mirror of https://github.com/vacp2p/nim-libp2p.git
feat: half closed channels
This commit is contained in:
parent
6058a3fc69
commit
9b485b3082
|
@ -8,27 +8,123 @@
|
||||||
## those terms.
|
## those terms.
|
||||||
|
|
||||||
import chronos
|
import chronos
|
||||||
import ../../stream/bufferstream
|
import ../../stream/bufferstream,
|
||||||
import types
|
../../stream/lpstream,
|
||||||
|
types, coder, ../../connection
|
||||||
|
|
||||||
type
|
type
|
||||||
Channel* = ref object of BufferStream
|
Channel* = ref object of BufferStream
|
||||||
id*: int
|
id*: int
|
||||||
|
conn*: Connection
|
||||||
initiator*: bool
|
initiator*: bool
|
||||||
isReset*: bool
|
isReset*: bool
|
||||||
closedLocal*: bool
|
closedLocal*: bool
|
||||||
closedRemote*: bool
|
closedRemote*: bool
|
||||||
handlerFuture*: Future[void]
|
handlerFuture*: Future[void]
|
||||||
|
msgCode*: MessageType
|
||||||
|
closeCode*: MessageType
|
||||||
|
resetCode*: MessageType
|
||||||
|
|
||||||
proc newChannel*(id: int,
|
proc newChannel*(id: int,
|
||||||
|
conn: Connection,
|
||||||
initiator: bool,
|
initiator: bool,
|
||||||
handler: WriteHandler,
|
|
||||||
size: int = MaxMsgSize): Channel =
|
size: int = MaxMsgSize): Channel =
|
||||||
new result
|
new result
|
||||||
result.id = id
|
result.id = id
|
||||||
|
result.conn = conn
|
||||||
result.initiator = initiator
|
result.initiator = initiator
|
||||||
result.initBufferStream(handler, size)
|
result.msgCode = if initiator: MessageType.MsgOut else: MessageType.MsgIn
|
||||||
|
result.closeCode = if initiator: MessageType.CloseOut else: MessageType.CloseIn
|
||||||
|
result.resetCode = if initiator: MessageType.ResetOut else: MessageType.ResetIn
|
||||||
|
|
||||||
proc closed*(s: Channel): bool = s.closedLocal and s.closedRemote
|
let chan = result
|
||||||
proc close*(s: Channel) {.async.} = discard
|
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
||||||
proc reset*(s: Channel) {.async.} = discard
|
await conn.writeHeader(id, chan.msgCode, data.len) # write header
|
||||||
|
await conn.write(data)
|
||||||
|
|
||||||
|
result.initBufferStream(writeHandler, size)
|
||||||
|
|
||||||
|
proc closeMessage(s: Channel) {.async, gcsafe.} =
|
||||||
|
await s.conn.writeHeader(s.id, s.closeCode, 0) # write header
|
||||||
|
|
||||||
|
proc closed*(s: Channel): bool =
|
||||||
|
s.closedLocal
|
||||||
|
|
||||||
|
proc closeRemote*(s: Channel) {.async.} =
|
||||||
|
s.closedRemote = true
|
||||||
|
|
||||||
|
method close*(s: Channel) {.async, gcsafe.} =
|
||||||
|
s.closedLocal = true
|
||||||
|
await s.closeMessage()
|
||||||
|
|
||||||
|
proc resetMessage(s: Channel) {.async, gcsafe.} =
|
||||||
|
await s.conn.writeHeader(s.id, s.resetCode, 0) # write header
|
||||||
|
|
||||||
|
proc remoteReset*(s: Channel) {.async, gcsafe.} =
|
||||||
|
await allFutures(s.close(), s.closeRemote())
|
||||||
|
s.isReset = true
|
||||||
|
|
||||||
|
proc reset*(s: Channel) {.async.} =
|
||||||
|
await allFutures(s.resetMessage(), s.remoteReset())
|
||||||
|
|
||||||
|
proc isReadEof(s: Channel): bool =
|
||||||
|
bool((s.closedRemote or s.closedLocal) and s.len() <= 0)
|
||||||
|
|
||||||
|
method pushTo*(s: Channel, data: seq[byte]): Future[void] {.gcsafe.} =
|
||||||
|
if s.closedRemote:
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall pushTo(BufferStream(s), data)
|
||||||
|
|
||||||
|
method read*(s: Channel, n = -1): Future[seq[byte]] {.gcsafe.} =
|
||||||
|
if s.isReadEof():
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall read(BufferStream(s), n)
|
||||||
|
|
||||||
|
method readExactly*(s: Channel,
|
||||||
|
pbytes: pointer,
|
||||||
|
nbytes: int):
|
||||||
|
Future[void] {.gcsafe.} =
|
||||||
|
if s.isReadEof():
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall readExactly(BufferStream(s), pbytes, nbytes)
|
||||||
|
|
||||||
|
method readLine*(s: Channel,
|
||||||
|
limit = 0,
|
||||||
|
sep = "\r\n"):
|
||||||
|
Future[string] {.gcsafe.} =
|
||||||
|
if s.isReadEof():
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall readLine(BufferStream(s), limit, sep)
|
||||||
|
|
||||||
|
method readOnce*(s: Channel,
|
||||||
|
pbytes: pointer,
|
||||||
|
nbytes: int):
|
||||||
|
Future[int] {.gcsafe.} =
|
||||||
|
if s.isReadEof():
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||||
|
|
||||||
|
method readUntil*(s: Channel,
|
||||||
|
pbytes: pointer, nbytes: int,
|
||||||
|
sep: seq[byte]):
|
||||||
|
Future[int] {.gcsafe.} =
|
||||||
|
if s.isReadEof():
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||||
|
|
||||||
|
method write*(s: Channel,
|
||||||
|
pbytes: pointer,
|
||||||
|
nbytes: int): Future[void] {.gcsafe.} =
|
||||||
|
if s.closedLocal:
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall write(BufferStream(s), pbytes, nbytes)
|
||||||
|
|
||||||
|
method write*(s: Channel, msg: string, msglen = -1) {.async, gcsafe.} =
|
||||||
|
if s.closedLocal:
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall write(BufferStream(s), msg, msglen)
|
||||||
|
|
||||||
|
method write*(s: Channel, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
||||||
|
if s.closedLocal:
|
||||||
|
raise newLPStreamClosedError()
|
||||||
|
result = procCall write(BufferStream(s), msg, msglen)
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
import chronos
|
import chronos
|
||||||
import ../../connection, ../../varint,
|
import ../../connection, ../../varint,
|
||||||
../../vbuffer, mplex, types,
|
../../vbuffer, types,
|
||||||
../../stream/lpstream
|
../../stream/lpstream
|
||||||
|
|
||||||
proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe.} =
|
proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe.} =
|
||||||
|
|
|
@ -26,10 +26,6 @@ type
|
||||||
proc newMplexUnknownMsgError(): ref MplexUnknownMsgError =
|
proc newMplexUnknownMsgError(): ref MplexUnknownMsgError =
|
||||||
result = newException(MplexUnknownMsgError, "Unknown mplex message type")
|
result = newException(MplexUnknownMsgError, "Unknown mplex message type")
|
||||||
|
|
||||||
##########################################
|
|
||||||
## Mplex
|
|
||||||
##########################################
|
|
||||||
|
|
||||||
proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] =
|
proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] =
|
||||||
if initiator:
|
if initiator:
|
||||||
result = m.remote
|
result = m.remote
|
||||||
|
@ -42,12 +38,7 @@ proc newStreamInternal*(m: Mplex,
|
||||||
Future[Channel] {.async, gcsafe.} =
|
Future[Channel] {.async, gcsafe.} =
|
||||||
## create new channel/stream
|
## create new channel/stream
|
||||||
let id = if initiator: m.currentId.inc(); m.currentId else: chanId
|
let id = if initiator: m.currentId.inc(); m.currentId else: chanId
|
||||||
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
result = newChannel(id, m.connection, initiator)
|
||||||
let msgType = if initiator: MessageType.MsgOut else: MessageType.MsgIn
|
|
||||||
await m.connection.writeHeader(id, msgType, data.len) # write header
|
|
||||||
await m.connection.write(data) # write data
|
|
||||||
|
|
||||||
result = newChannel(id, initiator, writeHandler)
|
|
||||||
m.getChannelList(initiator)[id] = result
|
m.getChannelList(initiator)[id] = result
|
||||||
|
|
||||||
proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} =
|
proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} =
|
||||||
|
@ -68,7 +59,7 @@ proc handle*(m: Mplex): Future[void] {.async, gcsafe.} =
|
||||||
await channel.pushTo(msg)
|
await channel.pushTo(msg)
|
||||||
of MessageType.CloseIn, MessageType.CloseOut:
|
of MessageType.CloseIn, MessageType.CloseOut:
|
||||||
let channel = m.getChannelList(initiator)[id.int]
|
let channel = m.getChannelList(initiator)[id.int]
|
||||||
await channel.close()
|
await channel.closeRemote()
|
||||||
of MessageType.ResetIn, MessageType.ResetOut:
|
of MessageType.ResetIn, MessageType.ResetOut:
|
||||||
let channel = m.getChannelList(initiator)[id.int]
|
let channel = m.getChannelList(initiator)[id.int]
|
||||||
await channel.reset()
|
await channel.reset()
|
||||||
|
|
|
@ -13,6 +13,7 @@ import ../../connection
|
||||||
const MaxMsgSize* = 1 shl 20 # 1mb
|
const MaxMsgSize* = 1 shl 20 # 1mb
|
||||||
const MaxChannels* = 1000
|
const MaxChannels* = 1000
|
||||||
const MplexCodec* = "/mplex/6.7.0"
|
const MplexCodec* = "/mplex/6.7.0"
|
||||||
|
const MaxReadWriteTime* = 5.seconds
|
||||||
|
|
||||||
type
|
type
|
||||||
MplexUnknownMsgError* = object of CatchableError
|
MplexUnknownMsgError* = object of CatchableError
|
||||||
|
|
|
@ -21,6 +21,7 @@ type
|
||||||
par*: ref Exception
|
par*: ref Exception
|
||||||
LPStreamWriteError* = object of LPStreamError
|
LPStreamWriteError* = object of LPStreamError
|
||||||
par*: ref Exception
|
par*: ref Exception
|
||||||
|
LPStreamClosedError* = object of LPStreamError
|
||||||
|
|
||||||
proc newLPStreamReadError*(p: ref Exception): ref Exception {.inline.} =
|
proc newLPStreamReadError*(p: ref Exception): ref Exception {.inline.} =
|
||||||
var w = newException(LPStreamReadError, "Read stream failed")
|
var w = newException(LPStreamReadError, "Read stream failed")
|
||||||
|
@ -43,6 +44,9 @@ proc newLPStreamLimitError*(): ref Exception {.inline.} =
|
||||||
proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
|
proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
|
||||||
result = newException(LPStreamIncorrectError, m)
|
result = newException(LPStreamIncorrectError, m)
|
||||||
|
|
||||||
|
proc newLPStreamClosedError*(): ref Exception {.inline.} =
|
||||||
|
result = newException(LPStreamClosedError, "Stream closed!")
|
||||||
|
|
||||||
method read*(s: LPStream, n = -1): Future[seq[byte]]
|
method read*(s: LPStream, n = -1): Future[seq[byte]]
|
||||||
{.base, async, gcsafe.} =
|
{.base, async, gcsafe.} =
|
||||||
assert(false, "not implemented!")
|
assert(false, "not implemented!")
|
||||||
|
|
|
@ -2,6 +2,7 @@ import unittest, sequtils, sugar
|
||||||
import chronos, nimcrypto/utils
|
import chronos, nimcrypto/utils
|
||||||
import ../libp2p/connection,
|
import ../libp2p/connection,
|
||||||
../libp2p/stream/lpstream,
|
../libp2p/stream/lpstream,
|
||||||
|
../libp2p/stream/bufferstream,
|
||||||
../libp2p/tcptransport,
|
../libp2p/tcptransport,
|
||||||
../libp2p/transport,
|
../libp2p/transport,
|
||||||
../libp2p/multiaddress,
|
../libp2p/multiaddress,
|
||||||
|
@ -92,7 +93,6 @@ suite "Mplex":
|
||||||
let dialFut = mplexDial.handle()
|
let dialFut = mplexDial.handle()
|
||||||
let stream = await mplexDial.newStream()
|
let stream = await mplexDial.newStream()
|
||||||
check cast[string](await stream.readLp()) == "Hello from stream!"
|
check cast[string](await stream.readLp()) == "Hello from stream!"
|
||||||
|
|
||||||
await conn.close()
|
await conn.close()
|
||||||
await dialFut
|
await dialFut
|
||||||
result = true
|
result = true
|
||||||
|
@ -130,3 +130,51 @@ suite "Mplex":
|
||||||
|
|
||||||
check:
|
check:
|
||||||
waitFor(testNewStream()) == true
|
waitFor(testNewStream()) == true
|
||||||
|
|
||||||
|
test "half closed - channel should close for write":
|
||||||
|
proc testClosedForWrite(): Future[void] {.async.} =
|
||||||
|
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||||
|
await chann.close()
|
||||||
|
await chann.write("Hello")
|
||||||
|
|
||||||
|
expect LPStreamClosedError:
|
||||||
|
waitFor(testClosedForWrite())
|
||||||
|
|
||||||
|
test "half closed - channel should close for read":
|
||||||
|
proc testClosedForRead(): Future[void] {.async.} =
|
||||||
|
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||||
|
await chann.closeRemote()
|
||||||
|
asyncDiscard chann.read()
|
||||||
|
|
||||||
|
expect LPStreamClosedError:
|
||||||
|
waitFor(testClosedForRead())
|
||||||
|
|
||||||
|
test "half closed - channel should close for read after eof":
|
||||||
|
proc testClosedForRead(): Future[void] {.async.} =
|
||||||
|
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||||
|
|
||||||
|
await chann.pushTo(cast[seq[byte]](toSeq("Hello!".items)))
|
||||||
|
await chann.close()
|
||||||
|
let msg = await chann.read()
|
||||||
|
asyncDiscard chann.read()
|
||||||
|
|
||||||
|
expect LPStreamClosedError:
|
||||||
|
waitFor(testClosedForRead())
|
||||||
|
|
||||||
|
test "reset - channel should fail reading":
|
||||||
|
proc testResetRead(): Future[void] {.async.} =
|
||||||
|
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||||
|
await chann.reset()
|
||||||
|
asyncDiscard chann.read()
|
||||||
|
|
||||||
|
expect LPStreamClosedError:
|
||||||
|
waitFor(testResetRead())
|
||||||
|
|
||||||
|
test "reset - channel should fail writing":
|
||||||
|
proc testResetWrite(): Future[void] {.async.} =
|
||||||
|
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||||
|
await chann.reset()
|
||||||
|
asyncDiscard chann.read()
|
||||||
|
|
||||||
|
expect LPStreamClosedError:
|
||||||
|
waitFor(testResetWrite())
|
||||||
|
|
Loading…
Reference in New Issue