mirror of https://github.com/vacp2p/nim-libp2p.git
reworked to make msg reading sequential
This commit is contained in:
parent
f761a7050e
commit
e53c87e197
|
@ -18,7 +18,7 @@ const DefaultChannelSize* = DefaultBufferSize * 64 # 64kb
|
|||
|
||||
type
|
||||
Channel* = ref object of BufferStream
|
||||
id*: int
|
||||
id*: uint
|
||||
name*: string
|
||||
conn*: Connection
|
||||
initiator*: bool
|
||||
|
@ -30,7 +30,7 @@ type
|
|||
closeCode*: MessageType
|
||||
resetCode*: MessageType
|
||||
|
||||
proc newChannel*(id: int,
|
||||
proc newChannel*(id: uint,
|
||||
conn: Connection,
|
||||
initiator: bool,
|
||||
name: string = "",
|
||||
|
@ -46,13 +46,12 @@ proc newChannel*(id: int,
|
|||
|
||||
let chan = result
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
||||
await conn.writeHeader(chan.id, chan.msgCode, data.len) # write header
|
||||
await conn.write(data)
|
||||
await conn.writeMsg(chan.id, chan.msgCode, data) # write header
|
||||
|
||||
result.initBufferStream(writeHandler, size)
|
||||
|
||||
proc closeMessage(s: Channel) {.async, gcsafe.} =
|
||||
await s.conn.writeHeader(s.id, s.closeCode) # write header
|
||||
await s.conn.writeMsg(s.id, s.closeCode) # write header
|
||||
|
||||
proc closed*(s: Channel): bool =
|
||||
s.closedLocal
|
||||
|
@ -65,7 +64,7 @@ method close*(s: Channel) {.async, gcsafe.} =
|
|||
await s.closeMessage()
|
||||
|
||||
proc resetMessage(s: Channel) {.async, gcsafe.} =
|
||||
await s.conn.writeHeader(s.id, s.resetCode)
|
||||
await s.conn.writeMsg(s.id, s.resetCode)
|
||||
|
||||
proc resetByRemote*(s: Channel) {.async, gcsafe.} =
|
||||
await allFutures(s.close(), s.closedByRemote())
|
||||
|
|
|
@ -18,33 +18,39 @@ import types,
|
|||
type
|
||||
Phase = enum Header, Size
|
||||
|
||||
proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe.} =
|
||||
proc readMplexVarint(conn: Connection): Future[uint] {.async, gcsafe.} =
|
||||
var
|
||||
header: uint
|
||||
varint: uint
|
||||
length: int
|
||||
res: VarintStatus
|
||||
var buffer = newSeq[byte](10)
|
||||
try:
|
||||
for i in 0..<len(buffer):
|
||||
await conn.readExactly(addr buffer[i], 1)
|
||||
res = LP.getUVarint(buffer.toOpenArray(0, i), length, header)
|
||||
res = LP.getUVarint(buffer.toOpenArray(0, i), length, varint)
|
||||
if res == VarintStatus.Success:
|
||||
let (id, msg) = (header shr 3, MessageType(header and 0x7))
|
||||
return (header shr 3, MessageType(header and 0x7))
|
||||
return varint
|
||||
if res != VarintStatus.Success:
|
||||
buffer.setLen(0)
|
||||
return
|
||||
except TransportIncompleteError:
|
||||
except TransportIncompleteError, AsyncStreamIncompleteError:
|
||||
buffer.setLen(0)
|
||||
raise newLPStreamIncompleteError()
|
||||
|
||||
proc writeHeader*(conn: Connection,
|
||||
id: int,
|
||||
proc readMsg*(conn: Connection): Future[(uint, MessageType, seq[byte])] {.async, gcsafe.} =
|
||||
let header = await conn.readMplexVarint()
|
||||
let dataLen = await conn.readMplexVarint()
|
||||
var data: seq[byte]
|
||||
if dataLen > 0.uint:
|
||||
data = await conn.read(dataLen.int)
|
||||
result = (header shr 3, MessageType(header and 0x7), data)
|
||||
|
||||
proc writeMsg*(conn: Connection,
|
||||
id: uint,
|
||||
msgType: MessageType,
|
||||
size: int = 0) {.async, gcsafe.} =
|
||||
data: seq[byte] = @[]) {.async, gcsafe.} =
|
||||
## write lenght prefixed
|
||||
var buf = initVBuffer()
|
||||
buf.writeVarint((id.uint shl 3) or msgType.uint)
|
||||
buf.writeVarint(size.uint) # size should be always sent
|
||||
buf.writeVarint((id shl 3) or ord(msgType).uint)
|
||||
buf.writeVarint(data.len().uint) # size should be always sent
|
||||
buf.finish()
|
||||
await conn.write(buf.buffer)
|
||||
await conn.write(buf.buffer & data)
|
||||
|
|
|
@ -29,18 +29,18 @@ import coder, types, channel,
|
|||
|
||||
type
|
||||
Mplex* = ref object of Muxer
|
||||
remote*: Table[int, Channel]
|
||||
local*: Table[int, Channel]
|
||||
currentId*: int
|
||||
remote*: Table[uint, Channel]
|
||||
local*: Table[uint, Channel]
|
||||
currentId*: uint
|
||||
maxChannels*: uint
|
||||
|
||||
proc newMplexNoSuchChannel(id: int, msgType: MessageType): ref MplexNoSuchChannel =
|
||||
proc newMplexNoSuchChannel(id: uint, msgType: MessageType): ref MplexNoSuchChannel =
|
||||
result = newException(MplexNoSuchChannel, &"No such channel id {$id} and message {$msgType}")
|
||||
|
||||
proc newMplexUnknownMsgError(): ref MplexUnknownMsgError =
|
||||
result = newException(MplexUnknownMsgError, "Unknown mplex message type")
|
||||
|
||||
proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] =
|
||||
proc getChannelList(m: Mplex, initiator: bool): var Table[uint, Channel] =
|
||||
if initiator:
|
||||
result = m.remote
|
||||
else:
|
||||
|
@ -48,7 +48,7 @@ proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] =
|
|||
|
||||
proc newStreamInternal*(m: Mplex,
|
||||
initiator: bool = true,
|
||||
chanId: int,
|
||||
chanId: uint = 0,
|
||||
name: string = ""):
|
||||
Future[Channel] {.async, gcsafe.} =
|
||||
## create new channel/stream
|
||||
|
@ -56,60 +56,45 @@ proc newStreamInternal*(m: Mplex,
|
|||
result = newChannel(id, m.connection, initiator, name)
|
||||
m.getChannelList(initiator)[id] = result
|
||||
|
||||
proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} =
|
||||
result = m.newStreamInternal(true, 0)
|
||||
|
||||
method handle*(m: Mplex): Future[void] {.async, gcsafe.} =
|
||||
try:
|
||||
while not m.connection.closed:
|
||||
let (id, msgType) = await m.connection.readHeader()
|
||||
let initiator = bool(ord(msgType) and 1)
|
||||
var channel: Channel
|
||||
if MessageType(msgType) != MessageType.New:
|
||||
let channels = m.getChannelList(initiator)
|
||||
if not channels.contains(id.int):
|
||||
raise newMplexNoSuchChannel(id.int, msgType)
|
||||
channel = channels[id.int]
|
||||
try:
|
||||
let (id, msgType, data) = await m.connection.readMsg()
|
||||
let initiator = bool(ord(msgType) and 1)
|
||||
var channel: Channel
|
||||
if MessageType(msgType) != MessageType.New:
|
||||
let channels = m.getChannelList(initiator)
|
||||
if not channels.contains(id):
|
||||
raise newMplexNoSuchChannel(id, msgType)
|
||||
channel = channels[id]
|
||||
|
||||
case msgType:
|
||||
of MessageType.New:
|
||||
var name: seq[byte]
|
||||
try:
|
||||
name = await m.connection.readLp()
|
||||
except LPStreamIncompleteError as exc:
|
||||
echo exc.msg
|
||||
except Exception as exc:
|
||||
echo exc.msg
|
||||
raise
|
||||
|
||||
let channel = await m.newStreamInternal(false, id.int, cast[string](name))
|
||||
if not isNil(m.streamHandler):
|
||||
channel.handlerFuture = m.streamHandler(newConnection(channel))
|
||||
of MessageType.MsgIn, MessageType.MsgOut:
|
||||
let msg = await m.connection.readLp()
|
||||
await channel.pushTo(msg)
|
||||
of MessageType.CloseIn, MessageType.CloseOut:
|
||||
await channel.closedByRemote()
|
||||
m.getChannelList(initiator).del(id.int)
|
||||
of MessageType.ResetIn, MessageType.ResetOut:
|
||||
await channel.resetByRemote()
|
||||
else: raise newMplexUnknownMsgError()
|
||||
finally:
|
||||
await m.connection.close()
|
||||
case msgType:
|
||||
of MessageType.New:
|
||||
channel = await m.newStreamInternal(false, id, cast[string](data))
|
||||
if not isNil(m.streamHandler):
|
||||
await m.streamHandler(newConnection(channel))
|
||||
of MessageType.MsgIn, MessageType.MsgOut:
|
||||
await channel.pushTo(data)
|
||||
of MessageType.CloseIn, MessageType.CloseOut:
|
||||
await channel.closedByRemote()
|
||||
m.getChannelList(initiator).del(id)
|
||||
of MessageType.ResetIn, MessageType.ResetOut:
|
||||
await channel.resetByRemote()
|
||||
else: raise newMplexUnknownMsgError()
|
||||
finally:
|
||||
await m.connection.close()
|
||||
|
||||
proc newMplex*(conn: Connection,
|
||||
maxChanns: uint = MaxChannels): Mplex =
|
||||
new result
|
||||
result.connection = conn
|
||||
result.maxChannels = maxChanns
|
||||
result.remote = initTable[int, Channel]()
|
||||
result.local = initTable[int, Channel]()
|
||||
result.remote = initTable[uint, Channel]()
|
||||
result.local = initTable[uint, Channel]()
|
||||
|
||||
method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} =
|
||||
let channel = await m.newStreamInternal()
|
||||
await m.connection.writeHeader(channel.id, MessageType.New, len(name))
|
||||
if name.len > 0:
|
||||
await m.connection.write(name)
|
||||
await m.connection.writeMsg(channel.id, MessageType.New, cast[seq[byte]](toSeq(name.items)))
|
||||
result = newConnection(channel)
|
||||
|
||||
method close*(m: Mplex) {.async, gcsafe.} =
|
||||
|
|
|
@ -12,65 +12,109 @@ import ../libp2p/connection,
|
|||
../libp2p/muxers/mplex/types,
|
||||
../libp2p/muxers/mplex/channel
|
||||
|
||||
type
|
||||
TestEncodeStream = ref object of LPStream
|
||||
handler*: proc(data: seq[byte])
|
||||
|
||||
method write*(s: TestEncodeStream,
|
||||
msg: seq[byte],
|
||||
msglen = -1):
|
||||
Future[void] {.gcsafe.} =
|
||||
s.handler(msg)
|
||||
|
||||
proc newTestEncodeStream(handler: proc(data: seq[byte])): TestEncodeStream =
|
||||
new result
|
||||
result.handler = handler
|
||||
|
||||
type
|
||||
TestDecodeStream = ref object of LPStream
|
||||
handler*: proc(data: seq[byte])
|
||||
step*: int
|
||||
msg*: seq[byte]
|
||||
|
||||
method readExactly*(s: TestDecodeStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.async, gcsafe.} =
|
||||
let buff: seq[byte] = s.msg
|
||||
copyMem(pbytes, unsafeAddr buff[s.step], nbytes)
|
||||
s.step += nbytes
|
||||
|
||||
proc newTestDecodeStream(): TestDecodeStream =
|
||||
new result
|
||||
result.step = 0
|
||||
result.msg = fromHex("8801023137")
|
||||
|
||||
suite "Mplex":
|
||||
test "encode header":
|
||||
test "encode header with channel id 0":
|
||||
proc testEncodeHeader(): Future[bool] {.async.} =
|
||||
proc encHandler(msg: seq[byte]) =
|
||||
check msg == fromHex("886f04")
|
||||
proc encHandler(msg: seq[byte]) {.async.} =
|
||||
check msg == fromHex("000873747265616d2031")
|
||||
|
||||
let conn = newConnection(newTestEncodeStream(encHandler))
|
||||
await conn.writeHeader(1777, MessageType.New, 4)
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(0, MessageType.New, cast[seq[byte]](toSeq("stream 1".items)))
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testEncodeHeader()) == true
|
||||
|
||||
test "decode header":
|
||||
proc testDecodeHeader(): Future[bool] {.async.} =
|
||||
let conn = newConnection(newTestDecodeStream())
|
||||
let (id, msgType) = await conn.readHeader()
|
||||
test "encode header with channel id other than 0":
|
||||
proc testEncodeHeader(): Future[bool] {.async.} =
|
||||
proc encHandler(msg: seq[byte]) {.async.} =
|
||||
check msg == fromHex("88010873747265616d2031")
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(17, MessageType.New, cast[seq[byte]](toSeq("stream 1".items)))
|
||||
result = true
|
||||
|
||||
check id == 17
|
||||
check:
|
||||
waitFor(testEncodeHeader()) == true
|
||||
|
||||
test "encode header and body with channel id 0":
|
||||
proc testEncodeHeaderBody(): Future[bool] {.async.} =
|
||||
var step = 0
|
||||
proc encHandler(msg: seq[byte]) {.async.} =
|
||||
check msg == fromHex("020873747265616d2031")
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]](toSeq("stream 1".items)))
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testEncodeHeaderBody()) == true
|
||||
|
||||
test "encode header and body with channel id other than 0":
|
||||
proc testEncodeHeaderBody(): Future[bool] {.async.} =
|
||||
var step = 0
|
||||
proc encHandler(msg: seq[byte]) {.async.} =
|
||||
check msg == fromHex("8a010873747265616d2031")
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]](toSeq("stream 1".items)))
|
||||
await conn.close()
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testEncodeHeaderBody()) == true
|
||||
|
||||
test "decode header with channel id 0":
|
||||
proc testDecodeHeader(): Future[bool] {.async.} =
|
||||
proc encHandler(msg: seq[byte]) {.async.} = discard
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await stream.pushTo(fromHex("000873747265616d2031"))
|
||||
let (id, msgType, data) = await conn.readMsg()
|
||||
|
||||
check id == 0
|
||||
check msgType == MessageType.New
|
||||
let data = await conn.readLp()
|
||||
check cast[string](data) == "17"
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testDecodeHeader()) == true
|
||||
|
||||
|
||||
test "decode header and body with channel id 0":
|
||||
proc testDecodeHeader(): Future[bool] {.async.} =
|
||||
proc encHandler(msg: seq[byte]) {.async.} = discard
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await stream.pushTo(fromHex("021668656C6C6F2066726F6D206368616E6E656C20302121"))
|
||||
let (id, msgType, data) = await conn.readMsg()
|
||||
|
||||
check id == 0
|
||||
check msgType == MessageType.MsgOut
|
||||
check cast[string](data) == "hello from channel 0!!"
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testDecodeHeader()) == true
|
||||
|
||||
test "decode header and body with channel id other than 0":
|
||||
proc testDecodeHeader(): Future[bool] {.async.} =
|
||||
proc encHandler(msg: seq[byte]) {.async.} = discard
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await stream.pushTo(fromHex("8a011668656C6C6F2066726F6D206368616E6E656C20302121"))
|
||||
let (id, msgType, data) = await conn.readMsg()
|
||||
|
||||
check id == 17
|
||||
check msgType == MessageType.MsgOut
|
||||
check cast[string](data) == "hello from channel 0!!"
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testDecodeHeader()) == true
|
||||
|
||||
test "e2e - read/write initiator":
|
||||
proc testNewStream(): Future[bool] {.async.} =
|
||||
let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53380")
|
||||
|
@ -92,10 +136,11 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = newMplex(conn)
|
||||
let dialFut = mplexDial.handle()
|
||||
let stream = await mplexDial.newStream()
|
||||
check cast[string](await stream.readLp()) == "Hello from stream!"
|
||||
let stream = await mplexDial.newStream("DIALER")
|
||||
let msg = cast[string](await stream.readLp())
|
||||
check msg == "Hello from stream!"
|
||||
await conn.close()
|
||||
await dialFut
|
||||
# await dialFut
|
||||
result = true
|
||||
|
||||
check:
|
||||
|
@ -122,11 +167,9 @@ suite "Mplex":
|
|||
let conn = await transport2.dial(ma)
|
||||
|
||||
let mplexDial = newMplex(conn)
|
||||
let dialFut = mplexDial.handle()
|
||||
let stream = await mplexDial.newStream()
|
||||
await stream.writeLp("Hello from stream!")
|
||||
await conn.close()
|
||||
await dialFut
|
||||
result = true
|
||||
|
||||
check:
|
||||
|
@ -136,16 +179,13 @@ suite "Mplex":
|
|||
proc testNewStream(): Future[bool] {.async.} =
|
||||
let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53382")
|
||||
|
||||
var count = 0
|
||||
var completionFut: Future[void] = newFuture[void]()
|
||||
var count = 1
|
||||
proc connHandler(conn: Connection) {.async, gcsafe.} =
|
||||
proc handleMplexListen(stream: Connection) {.async, gcsafe.} =
|
||||
let msg = await stream.readLp()
|
||||
check cast[string](msg) == &"Hello from stream {count}!"
|
||||
check cast[string](msg) == &"stream {count}!"
|
||||
count.inc
|
||||
await stream.close()
|
||||
if count == 11:
|
||||
completionFut.complete()
|
||||
|
||||
let mplexListen = newMplex(conn)
|
||||
mplexListen.streamHandler = handleMplexListen
|
||||
|
@ -158,18 +198,12 @@ suite "Mplex":
|
|||
let conn = await transport2.dial(ma)
|
||||
|
||||
let mplexDial = newMplex(conn)
|
||||
asyncCheck mplexDial.handle()
|
||||
|
||||
for i in 0..10:
|
||||
for i in 1..<10:
|
||||
let stream = await mplexDial.newStream()
|
||||
await stream.writeLp(&"Hello from stream {i}!")
|
||||
|
||||
await completionFut
|
||||
# closing the connection doesn't transfer all the data
|
||||
# this seems to be a bug in chronos
|
||||
# await conn.close()
|
||||
check count == 11
|
||||
await stream.writeLp(&"stream {i}!")
|
||||
await stream.close()
|
||||
|
||||
await conn.close()
|
||||
result = true
|
||||
|
||||
check:
|
||||
|
@ -177,7 +211,8 @@ suite "Mplex":
|
|||
|
||||
test "half closed - channel should close for write":
|
||||
proc testClosedForWrite(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.close()
|
||||
await chann.write("Hello")
|
||||
|
||||
|
@ -186,7 +221,8 @@ suite "Mplex":
|
|||
|
||||
test "half closed - channel should close for read":
|
||||
proc testClosedForRead(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.closedByRemote()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
|
@ -195,7 +231,8 @@ suite "Mplex":
|
|||
|
||||
test "half closed - channel should close for read after eof":
|
||||
proc testClosedForRead(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
|
||||
await chann.pushTo(cast[seq[byte]](toSeq("Hello!".items)))
|
||||
await chann.close()
|
||||
|
@ -207,7 +244,8 @@ suite "Mplex":
|
|||
|
||||
test "reset - channel should fail reading":
|
||||
proc testResetRead(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.reset()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
|
@ -216,7 +254,8 @@ suite "Mplex":
|
|||
|
||||
test "reset - channel should fail writing":
|
||||
proc testResetWrite(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.reset()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
|
@ -225,7 +264,8 @@ suite "Mplex":
|
|||
|
||||
test "should not allow pushing data to channel when remote end closed":
|
||||
proc testResetWrite(): Future[void] {.async.} =
|
||||
let chann = newChannel(1, newConnection(new LPStream), true)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.closedByRemote()
|
||||
await chann.pushTo(@[byte(1)])
|
||||
|
||||
|
|
Loading…
Reference in New Issue