Feat/conn cleanup (#41)
Backporting proper connection cleanup from #36 to align with latest chronos changes. * add close event * use proper varint encoding * add proper channel cleanup in mplex * add connection cleanup in secio * tidy up * add dollar operator * fix tests * don't close connections prematurely * handle closing streams properly * misc * implement address filtering logic * adding pipe tests * don't use gcsafe if not needed * misc * proper connection cleanup and stream muxing * re-enable pubsub tests
This commit is contained in:
parent
1df16bdbce
commit
903e79ede1
|
@ -7,7 +7,7 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import chronos, options, chronicles
|
||||
import chronos, chronicles
|
||||
import peerinfo,
|
||||
multiaddress,
|
||||
stream/lpstream,
|
||||
|
@ -26,15 +26,28 @@ type
|
|||
InvalidVarintException = object of LPStreamError
|
||||
|
||||
proc newInvalidVarintException*(): ref InvalidVarintException =
|
||||
result = newException(InvalidVarintException, "unable to prase varint")
|
||||
newException(InvalidVarintException, "unable to prase varint")
|
||||
|
||||
proc newConnection*(stream: LPStream): Connection =
|
||||
## create a new Connection for the specified async reader/writer
|
||||
new result
|
||||
result.stream = stream
|
||||
result.closeEvent = newAsyncEvent()
|
||||
|
||||
# bind stream's close event to connection's close
|
||||
# to ensure correct close propagation
|
||||
let this = result
|
||||
if not isNil(result.stream.closeEvent):
|
||||
result.stream.closeEvent.wait().
|
||||
addCallback(
|
||||
proc (udata: pointer) =
|
||||
if not this.closed:
|
||||
trace "closing this connection because wrapped stream closed"
|
||||
asyncCheck this.close()
|
||||
)
|
||||
|
||||
method read*(s: Connection, n = -1): Future[seq[byte]] {.gcsafe.} =
|
||||
result = s.stream.read(n)
|
||||
s.stream.read(n)
|
||||
|
||||
method readExactly*(s: Connection,
|
||||
pbytes: pointer,
|
||||
|
@ -44,13 +57,13 @@ method readExactly*(s: Connection,
|
|||
|
||||
method readLine*(s: Connection,
|
||||
limit = 0,
|
||||
sep = "\r\n"):
|
||||
sep = "\r\n"):
|
||||
Future[string] {.gcsafe.} =
|
||||
s.stream.readLine(limit, sep)
|
||||
|
||||
method readOnce*(s: Connection,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
nbytes: int):
|
||||
Future[int] {.gcsafe.} =
|
||||
s.stream.readOnce(pbytes, nbytes)
|
||||
|
||||
|
@ -61,15 +74,15 @@ method readUntil*(s: Connection,
|
|||
Future[int] {.gcsafe.} =
|
||||
s.stream.readUntil(pbytes, nbytes, sep)
|
||||
|
||||
method write*(s: Connection,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
method write*(s: Connection,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.gcsafe.} =
|
||||
s.stream.write(pbytes, nbytes)
|
||||
|
||||
method write*(s: Connection,
|
||||
msg: string,
|
||||
msglen = -1):
|
||||
method write*(s: Connection,
|
||||
msg: string,
|
||||
msglen = -1):
|
||||
Future[void] {.gcsafe.} =
|
||||
s.stream.write(msg, msglen)
|
||||
|
||||
|
@ -79,9 +92,20 @@ method write*(s: Connection,
|
|||
Future[void] {.gcsafe.} =
|
||||
s.stream.write(msg, msglen)
|
||||
|
||||
method closed*(s: Connection): bool =
|
||||
if isNil(s.stream):
|
||||
return false
|
||||
|
||||
result = s.stream.closed
|
||||
|
||||
method close*(s: Connection) {.async, gcsafe.} =
|
||||
await s.stream.close()
|
||||
s.closed = true
|
||||
trace "closing connection"
|
||||
if not s.closed:
|
||||
if not isNil(s.stream) and not s.stream.closed:
|
||||
await s.stream.close()
|
||||
s.closeEvent.fire()
|
||||
s.isClosed = true
|
||||
trace "connection closed", closed = s.closed
|
||||
|
||||
proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
|
||||
## read lenght prefixed msg
|
||||
|
@ -100,21 +124,23 @@ proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
|
|||
raise newInvalidVarintException()
|
||||
result.setLen(size)
|
||||
if size > 0.uint:
|
||||
trace "reading exact bytes from stream", size = size
|
||||
await s.readExactly(addr result[0], int(size))
|
||||
except LPStreamIncompleteError, LPStreamReadError:
|
||||
trace "remote connection closed", exc = getCurrentExceptionMsg()
|
||||
except LPStreamIncompleteError as exc:
|
||||
trace "remote connection ended unexpectedly", exc = exc.msg
|
||||
except LPStreamReadError as exc:
|
||||
trace "couldn't read from stream", exc = exc.msg
|
||||
|
||||
proc writeLp*(s: Connection, msg: string | seq[byte]): Future[void] {.gcsafe.} =
|
||||
## write lenght prefixed
|
||||
var buf = initVBuffer()
|
||||
buf.writeSeq(msg)
|
||||
buf.finish()
|
||||
result = s.write(buf.buffer)
|
||||
s.write(buf.buffer)
|
||||
|
||||
method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} =
|
||||
## get resolved multiaddresses for the connection
|
||||
result = c.observedAddrs
|
||||
|
||||
proc `$`*(conn: Connection): string =
|
||||
if conn.peerInfo.peerId.isSome:
|
||||
result = $(conn.peerInfo.peerId.get())
|
||||
result = $(conn.peerInfo)
|
||||
|
|
|
@ -855,7 +855,7 @@ proc connect*(api: DaemonAPI, peer: PeerID,
|
|||
timeout))
|
||||
pb.withMessage() do:
|
||||
discard
|
||||
finally:
|
||||
except:
|
||||
await api.closeConnection(transp)
|
||||
|
||||
proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} =
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import sequtils, strutils, strformat
|
||||
import strutils
|
||||
import chronos, chronicles
|
||||
import connection,
|
||||
varint,
|
||||
vbuffer,
|
||||
protocols/protocol
|
||||
protocols/protocol,
|
||||
stream/lpstream
|
||||
|
||||
logScope:
|
||||
topic = "Multistream"
|
||||
|
@ -56,16 +56,16 @@ proc select*(m: MultisteamSelect,
|
|||
trace "selecting proto", proto = proto
|
||||
await conn.writeLp((proto[0] & "\n")) # select proto
|
||||
|
||||
result = cast[string](await conn.readLp()) # read ms header
|
||||
result = cast[string]((await conn.readLp())) # read ms header
|
||||
result.removeSuffix("\n")
|
||||
if result != Codec:
|
||||
trace "handshake failed", codec = result
|
||||
trace "handshake failed", codec = result.toHex()
|
||||
return ""
|
||||
|
||||
if proto.len() == 0: # no protocols, must be a handshake call
|
||||
return
|
||||
|
||||
result = cast[string](await conn.readLp()) # read the first proto
|
||||
result = cast[string]((await conn.readLp())) # read the first proto
|
||||
trace "reading first requested proto"
|
||||
result.removeSuffix("\n")
|
||||
if result == proto[0]:
|
||||
|
@ -76,7 +76,7 @@ proc select*(m: MultisteamSelect,
|
|||
trace "selecting one of several protos"
|
||||
for p in proto[1..<proto.len()]:
|
||||
await conn.writeLp((p & "\n")) # select proto
|
||||
result = cast[string](await conn.readLp()) # read the first proto
|
||||
result = cast[string]((await conn.readLp())) # read the first proto
|
||||
result.removeSuffix("\n")
|
||||
if result == p:
|
||||
trace "selected protocol", protocol = result
|
||||
|
@ -102,7 +102,7 @@ proc list*(m: MultisteamSelect,
|
|||
await conn.write(m.ls) # send ls
|
||||
|
||||
var list = newSeq[string]()
|
||||
let ms = cast[string](await conn.readLp())
|
||||
let ms = cast[string]((await conn.readLp()))
|
||||
for s in ms.split("\n"):
|
||||
if s.len() > 0:
|
||||
list.add(s)
|
||||
|
@ -111,8 +111,10 @@ proc list*(m: MultisteamSelect,
|
|||
|
||||
proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
|
||||
trace "handle: starting multistream handling"
|
||||
while not conn.closed:
|
||||
var ms = cast[string](await conn.readLp())
|
||||
try:
|
||||
while not conn.closed:
|
||||
await sleepAsync(1.millis)
|
||||
var ms = cast[string]((await conn.readLp()))
|
||||
ms.removeSuffix("\n")
|
||||
|
||||
trace "handle: got request for ", ms
|
||||
|
@ -142,11 +144,15 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
|
|||
try:
|
||||
await h.protocol.handler(conn, ms)
|
||||
return
|
||||
except Exception as exc:
|
||||
warn "exception while handling ", msg = exc.msg
|
||||
except CatchableError as exc:
|
||||
warn "exception while handling", msg = exc.msg
|
||||
return
|
||||
warn "no handlers for ", protocol = ms
|
||||
await conn.write(m.na)
|
||||
except CatchableError as exc:
|
||||
trace "exception occured", exc = exc.msg
|
||||
finally:
|
||||
trace "leaving multistream loop"
|
||||
|
||||
proc addHandler*[T: LPProtocol](m: MultisteamSelect,
|
||||
codec: string,
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import chronos, options, sequtils, strformat
|
||||
import chronos, options
|
||||
import nimcrypto/utils, chronicles
|
||||
import types,
|
||||
../../connection,
|
||||
|
@ -29,31 +29,33 @@ proc readMplexVarint(conn: Connection): Future[Option[uint]] {.async, gcsafe.} =
|
|||
varint: uint
|
||||
length: int
|
||||
res: VarintStatus
|
||||
var buffer = newSeq[byte](10)
|
||||
buffer = newSeq[byte](10)
|
||||
|
||||
result = none(uint)
|
||||
try:
|
||||
for i in 0..<len(buffer):
|
||||
await conn.readExactly(addr buffer[i], 1)
|
||||
res = LP.getUVarint(buffer.toOpenArray(0, i), length, varint)
|
||||
if res == VarintStatus.Success:
|
||||
return some(varint)
|
||||
if not conn.closed:
|
||||
await conn.readExactly(addr buffer[i], 1)
|
||||
res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint)
|
||||
if res == VarintStatus.Success:
|
||||
return some(varint)
|
||||
if res != VarintStatus.Success:
|
||||
raise newInvalidVarintException()
|
||||
except LPStreamIncompleteError:
|
||||
trace "unable to read varint", exc = getCurrentExceptionMsg()
|
||||
except LPStreamIncompleteError as exc:
|
||||
trace "unable to read varint", exc = exc.msg
|
||||
|
||||
proc readMsg*(conn: Connection): Future[Option[Msg]] {.async, gcsafe.} =
|
||||
let headerVarint = await conn.readMplexVarint()
|
||||
if headerVarint.isNone:
|
||||
return
|
||||
|
||||
trace "readMsg: read header varint ", varint = headerVarint
|
||||
trace "read header varint", varint = headerVarint
|
||||
|
||||
let dataLenVarint = await conn.readMplexVarint()
|
||||
var data: seq[byte]
|
||||
if dataLenVarint.isSome and dataLenVarint.get() > 0.uint:
|
||||
trace "readMsg: read size varint ", varint = dataLenVarint
|
||||
data = await conn.read(dataLenVarint.get().int)
|
||||
trace "read size varint", varint = dataLenVarint
|
||||
|
||||
let header = headerVarint.get()
|
||||
result = some((header shr 3, MessageType(header and 0x7), data))
|
||||
|
@ -64,11 +66,13 @@ proc writeMsg*(conn: Connection,
|
|||
data: seq[byte] = @[]) {.async, gcsafe.} =
|
||||
## write lenght prefixed
|
||||
var buf = initVBuffer()
|
||||
let header = (id shl 3 or ord(msgType).uint)
|
||||
buf.writeVarint(id shl 3 or ord(msgType).uint)
|
||||
buf.writeVarint(data.len().uint) # size should be always sent
|
||||
buf.writePBVarint(id shl 3 or ord(msgType).uint)
|
||||
buf.writePBVarint(data.len().uint) # size should be always sent
|
||||
buf.finish()
|
||||
await conn.write(buf.buffer & data)
|
||||
try:
|
||||
await conn.write(buf.buffer & data)
|
||||
except LPStreamIncompleteError as exc:
|
||||
trace "unable to send message", exc = exc.msg
|
||||
|
||||
proc writeMsg*(conn: Connection,
|
||||
id: uint,
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import strformat
|
||||
import chronos, chronicles
|
||||
import types,
|
||||
coder,
|
||||
|
@ -52,99 +51,110 @@ proc newChannel*(id: uint,
|
|||
result.asyncLock = newAsyncLock()
|
||||
|
||||
let chan = result
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async.} =
|
||||
# writes should happen in sequence
|
||||
await chan.asyncLock.acquire()
|
||||
trace "writeHandler: sending data ", data = data.toHex(), id = chan.id
|
||||
trace "sending data ", data = data.toHex(),
|
||||
id = chan.id,
|
||||
initiator = chan.initiator
|
||||
|
||||
await conn.writeMsg(chan.id, chan.msgCode, data) # write header
|
||||
chan.asyncLock.release()
|
||||
|
||||
result.initBufferStream(writeHandler, size)
|
||||
|
||||
proc closeMessage(s: LPChannel) {.async, gcsafe.} =
|
||||
proc closeMessage(s: LPChannel) {.async.} =
|
||||
await s.conn.writeMsg(s.id, s.closeCode) # write header
|
||||
|
||||
proc closed*(s: LPChannel): bool =
|
||||
s.closedLocal and s.closedLocal
|
||||
|
||||
proc closedByRemote*(s: LPChannel) {.async.} =
|
||||
s.closedRemote = true
|
||||
|
||||
proc cleanUp*(s: LPChannel): Future[void] =
|
||||
# method which calls the underlying buffer's `close`
|
||||
# method used instead of `close` since it's overloaded to
|
||||
# simulate half-closed streams
|
||||
result = procCall close(BufferStream(s))
|
||||
|
||||
proc open*(s: LPChannel): Future[void] =
|
||||
s.conn.writeMsg(s.id, MessageType.New, s.name)
|
||||
|
||||
method close*(s: LPChannel) {.async, gcsafe.} =
|
||||
s.closedLocal = true
|
||||
await s.closeMessage()
|
||||
|
||||
proc resetMessage(s: LPChannel) {.async, gcsafe.} =
|
||||
proc resetMessage(s: LPChannel) {.async.} =
|
||||
await s.conn.writeMsg(s.id, s.resetCode)
|
||||
|
||||
proc resetByRemote*(s: LPChannel) {.async, gcsafe.} =
|
||||
proc resetByRemote*(s: LPChannel) {.async.} =
|
||||
await allFutures(s.close(), s.closedByRemote())
|
||||
s.isReset = true
|
||||
|
||||
proc reset*(s: LPChannel) {.async.} =
|
||||
await allFutures(s.resetMessage(), s.resetByRemote())
|
||||
|
||||
proc isReadEof(s: LPChannel): bool =
|
||||
bool((s.closedRemote or s.closedLocal) and s.len() < 1)
|
||||
method closed*(s: LPChannel): bool =
|
||||
result = s.closedRemote and s.len == 0
|
||||
|
||||
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] {.gcsafe.} =
|
||||
if s.closedRemote:
|
||||
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] =
|
||||
if s.closedRemote or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
trace "pushing data to channel", data = data.toHex(),
|
||||
id = s.id,
|
||||
initiator = s.initiator
|
||||
|
||||
result = procCall pushTo(BufferStream(s), data)
|
||||
|
||||
method read*(s: LPChannel, n = -1): Future[seq[byte]] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
method read*(s: LPChannel, n = -1): Future[seq[byte]] =
|
||||
if s.closed or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
result = procCall read(BufferStream(s), n)
|
||||
|
||||
method readExactly*(s: LPChannel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
method readExactly*(s: LPChannel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] =
|
||||
if s.closed or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readExactly(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method readLine*(s: LPChannel,
|
||||
limit = 0,
|
||||
sep = "\r\n"):
|
||||
Future[string] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
Future[string] =
|
||||
if s.closed or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readLine(BufferStream(s), limit, sep)
|
||||
|
||||
method readOnce*(s: LPChannel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
Future[int] =
|
||||
if s.closed or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method readUntil*(s: LPChannel,
|
||||
pbytes: pointer, nbytes: int,
|
||||
sep: seq[byte]):
|
||||
Future[int] {.gcsafe.} =
|
||||
if s.isReadEof():
|
||||
Future[int] =
|
||||
if s.closed or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method write*(s: LPChannel,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.gcsafe.} =
|
||||
if s.closedLocal:
|
||||
nbytes: int): Future[void] =
|
||||
if s.closedLocal or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall write(BufferStream(s), pbytes, nbytes)
|
||||
|
||||
method write*(s: LPChannel, msg: string, msglen = -1) {.async, gcsafe.} =
|
||||
if s.closedLocal:
|
||||
method write*(s: LPChannel, msg: string, msglen = -1) {.async.} =
|
||||
if s.closedLocal or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall write(BufferStream(s), msg, msglen)
|
||||
|
||||
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
||||
if s.closedLocal:
|
||||
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} =
|
||||
if s.closedLocal or s.isReset:
|
||||
raise newLPStreamClosedError()
|
||||
result = procCall write(BufferStream(s), msg, msglen)
|
||||
|
|
|
@ -11,16 +11,14 @@
|
|||
## Timeouts and message limits are still missing
|
||||
## they need to be added ASAP
|
||||
|
||||
import tables, sequtils, options, strformat
|
||||
import tables, sequtils, options
|
||||
import chronos, chronicles
|
||||
import coder, types, lpchannel,
|
||||
../muxer,
|
||||
../../varint,
|
||||
import ../muxer,
|
||||
../../connection,
|
||||
../../vbuffer,
|
||||
../../protocols/protocol,
|
||||
../../stream/bufferstream,
|
||||
../../stream/lpstream
|
||||
../../stream/lpstream,
|
||||
coder,
|
||||
types,
|
||||
lpchannel
|
||||
|
||||
logScope:
|
||||
topic = "Mplex"
|
||||
|
@ -34,9 +32,11 @@ type
|
|||
|
||||
proc getChannelList(m: Mplex, initiator: bool): var Table[uint, LPChannel] =
|
||||
if initiator:
|
||||
result = m.remote
|
||||
else:
|
||||
trace "picking local channels", initiator = initiator
|
||||
result = m.local
|
||||
else:
|
||||
trace "picking remote channels", initiator = initiator
|
||||
result = m.remote
|
||||
|
||||
proc newStreamInternal*(m: Mplex,
|
||||
initiator: bool = true,
|
||||
|
@ -45,17 +45,28 @@ proc newStreamInternal*(m: Mplex,
|
|||
Future[LPChannel] {.async, gcsafe.} =
|
||||
## create new channel/stream
|
||||
let id = if initiator: m.currentId.inc(); m.currentId else: chanId
|
||||
trace "creating new channel", channelId = id, initiator = initiator
|
||||
result = newChannel(id, m.connection, initiator, name)
|
||||
m.getChannelList(initiator)[id] = result
|
||||
|
||||
proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async, inline.} =
|
||||
## call the channel's `close` to signal the
|
||||
## remote that the channel is closing
|
||||
if not isNil(chann) and not chann.closed:
|
||||
await chann.close()
|
||||
await chann.cleanUp()
|
||||
m.getChannelList(initiator).del(chann.id)
|
||||
trace "cleaned up channel", id = chann.id
|
||||
|
||||
method handle*(m: Mplex) {.async, gcsafe.} =
|
||||
trace "starting mplex main loop"
|
||||
try:
|
||||
while not m.connection.closed:
|
||||
trace "waiting for data"
|
||||
let msg = await m.connection.readMsg()
|
||||
if msg.isNone:
|
||||
# TODO: allow poll with timeout to avoid using `sleepAsync`
|
||||
await sleepAsync(10.millis)
|
||||
await sleepAsync(1.millis)
|
||||
continue
|
||||
|
||||
let (id, msgType, data) = msg.get()
|
||||
|
@ -63,8 +74,11 @@ method handle*(m: Mplex) {.async, gcsafe.} =
|
|||
var channel: LPChannel
|
||||
if MessageType(msgType) != MessageType.New:
|
||||
let channels = m.getChannelList(initiator)
|
||||
if not channels.contains(id):
|
||||
trace "handle: Channel with id and msg type ", id = id, msg = msgType
|
||||
if id notin channels:
|
||||
trace "Channel not found, skipping", id = id,
|
||||
initiator = initiator,
|
||||
msg = msgType
|
||||
await sleepAsync(1.millis)
|
||||
continue
|
||||
channel = channels[id]
|
||||
|
||||
|
@ -72,36 +86,44 @@ method handle*(m: Mplex) {.async, gcsafe.} =
|
|||
of MessageType.New:
|
||||
let name = cast[string](data)
|
||||
channel = await m.newStreamInternal(false, id, name)
|
||||
trace "handle: created channel ", id = id, name = name
|
||||
trace "created channel", id = id, name = name, inititator = true
|
||||
if not isNil(m.streamHandler):
|
||||
let stream = newConnection(channel)
|
||||
stream.peerInfo = m.connection.peerInfo
|
||||
let handlerFut = m.streamHandler(stream)
|
||||
|
||||
# channel cleanup routine
|
||||
proc cleanUpChan(udata: pointer) {.gcsafe.} =
|
||||
if handlerFut.finished:
|
||||
channel.close().addCallback(
|
||||
proc(udata: pointer) =
|
||||
channel.cleanUp()
|
||||
.addCallback(proc(udata: pointer) =
|
||||
trace "handle: cleaned up channel ", id = id))
|
||||
handlerFut.addCallback(cleanUpChan)
|
||||
# cleanup channel once handler is finished
|
||||
# stream.closeEvent.wait().addCallback(
|
||||
# proc(udata: pointer) =
|
||||
# asyncCheck cleanupChann(m, channel, initiator))
|
||||
|
||||
asyncCheck m.streamHandler(stream)
|
||||
|
||||
continue
|
||||
of MessageType.MsgIn, MessageType.MsgOut:
|
||||
trace "handle: pushing data to channel ", id = id, msgType = msgType
|
||||
trace "pushing data to channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType
|
||||
|
||||
await channel.pushTo(data)
|
||||
of MessageType.CloseIn, MessageType.CloseOut:
|
||||
trace "handle: closing channel ", id = id, msgType = msgType
|
||||
trace "closing channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType
|
||||
|
||||
await channel.closedByRemote()
|
||||
m.getChannelList(initiator).del(id)
|
||||
of MessageType.ResetIn, MessageType.ResetOut:
|
||||
trace "handle: resetting channel ", id = id
|
||||
trace "resetting channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType
|
||||
|
||||
await channel.resetByRemote()
|
||||
m.getChannelList(initiator).del(id)
|
||||
break
|
||||
except:
|
||||
error "exception occurred", exception = getCurrentExceptionMsg()
|
||||
except CatchableError as exc:
|
||||
trace "exception occurred", exception = exc.msg
|
||||
finally:
|
||||
trace "stopping mplex main loop"
|
||||
await m.connection.close()
|
||||
|
||||
proc newMplex*(conn: Connection,
|
||||
|
@ -112,13 +134,20 @@ proc newMplex*(conn: Connection,
|
|||
result.remote = initTable[uint, LPChannel]()
|
||||
result.local = initTable[uint, LPChannel]()
|
||||
|
||||
let m = result
|
||||
conn.closeEvent.wait().addCallback(
|
||||
proc(udata: pointer) =
|
||||
asyncCheck m.close()
|
||||
)
|
||||
|
||||
method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} =
|
||||
let channel = await m.newStreamInternal()
|
||||
await m.connection.writeMsg(channel.id, MessageType.New, name)
|
||||
# TODO: open the channel (this should be lazy)
|
||||
await channel.open()
|
||||
result = newConnection(channel)
|
||||
result.peerInfo = m.connection.peerInfo
|
||||
|
||||
method close*(m: Mplex) {.async, gcsafe.} =
|
||||
await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.close())),
|
||||
allFutures(toSeq(m.local.values).mapIt(it.close()))])
|
||||
m.connection.reset()
|
||||
trace "closing mplex muxer"
|
||||
await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())),
|
||||
allFutures(toSeq(m.local.values).mapIt(it.reset()))])
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
## those terms.
|
||||
|
||||
import chronos
|
||||
import ../../connection
|
||||
|
||||
const MaxMsgSize* = 1 shl 20 # 1mb
|
||||
const MaxChannels* = 1000
|
||||
|
|
|
@ -10,7 +10,27 @@
|
|||
import options
|
||||
import peer, multiaddress
|
||||
|
||||
type PeerInfo* = object of RootObj
|
||||
peerId*: Option[PeerID]
|
||||
addrs*: seq[MultiAddress]
|
||||
protocols*: seq[string]
|
||||
type
|
||||
PeerInfo* = object of RootObj
|
||||
peerId*: Option[PeerID]
|
||||
addrs*: seq[MultiAddress]
|
||||
protocols*: seq[string]
|
||||
|
||||
proc id*(p: PeerInfo): string =
|
||||
if p.peerId.isSome:
|
||||
result = p.peerId.get().pretty
|
||||
|
||||
proc `$`*(p: PeerInfo): string =
|
||||
if p.peerId.isSome:
|
||||
result.add("PeerID: ")
|
||||
result.add(p.id & "\n")
|
||||
|
||||
if p.addrs.len > 0:
|
||||
result.add("Peer Addrs: ")
|
||||
for a in p.addrs:
|
||||
result.add($a & "\n")
|
||||
|
||||
if p.protocols.len > 0:
|
||||
result.add("Protocols: ")
|
||||
for proto in p.protocols:
|
||||
result.add(proto & "\n")
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import options, strformat
|
||||
import options
|
||||
import chronos, chronicles
|
||||
import ../protobuf/minprotobuf,
|
||||
../peerinfo,
|
||||
|
@ -115,14 +115,14 @@ method init*(p: Identify) =
|
|||
trace "handling identify request"
|
||||
var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
|
||||
await conn.writeLp(pb.buffer)
|
||||
# await conn.close() #TODO: investigate why this breaks
|
||||
|
||||
p.handler = handle
|
||||
p.codec = IdentifyCodec
|
||||
|
||||
proc identify*(p: Identify,
|
||||
conn: Connection,
|
||||
remotePeerInfo: PeerInfo):
|
||||
Future[IdentifyInfo] {.async.} =
|
||||
remotePeerInfo: PeerInfo): Future[IdentifyInfo] {.async, gcsafe.} =
|
||||
var message = await conn.readLp()
|
||||
if len(message) == 0:
|
||||
trace "identify: Invalid or empty message received!"
|
||||
|
@ -139,7 +139,7 @@ proc identify*(p: Identify,
|
|||
if peer != remotePeerInfo.peerId.get():
|
||||
trace "Peer ids don't match",
|
||||
remote = peer.pretty(),
|
||||
local = remotePeerInfo.peerId.get().pretty()
|
||||
local = remotePeerInfo.id
|
||||
|
||||
raise newException(IdentityNoMatchError,
|
||||
"Peer ids don't match")
|
||||
|
@ -149,5 +149,4 @@ proc identify*(p: Identify,
|
|||
proc push*(p: Identify, conn: Connection) {.async.} =
|
||||
await conn.write(IdentifyPushCodec)
|
||||
var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
|
||||
let length = pb.getLen()
|
||||
await conn.writeLp(pb.buffer)
|
||||
|
|
|
@ -8,9 +8,7 @@
|
|||
## those terms.
|
||||
|
||||
import chronos
|
||||
import ../connection,
|
||||
../peerinfo,
|
||||
../multiaddress
|
||||
import ../connection
|
||||
|
||||
type
|
||||
LPProtoHandler* = proc (conn: Connection,
|
||||
|
|
|
@ -14,6 +14,7 @@ import rpcmsg,
|
|||
../../peer,
|
||||
../../peerinfo,
|
||||
../../connection,
|
||||
../../stream/lpstream,
|
||||
../../crypto/crypto,
|
||||
../../protobuf/minprotobuf
|
||||
|
||||
|
@ -45,7 +46,7 @@ proc handle*(p: PubSubPeer) {.async, gcsafe.} =
|
|||
trace "Decoded msg from peer", peer = p.id, msg = msg
|
||||
await p.handler(p, @[msg])
|
||||
except:
|
||||
error "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
|
||||
trace "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
|
||||
finally:
|
||||
trace "closing connection to pubsub peer", peer = p.id
|
||||
await p.conn.close()
|
||||
|
|
|
@ -8,8 +8,7 @@
|
|||
## those terms.
|
||||
|
||||
import chronos
|
||||
import secure,
|
||||
../../connection
|
||||
import secure, ../../connection
|
||||
|
||||
const PlainTextCodec* = "/plaintext/1.0.0"
|
||||
|
||||
|
|
|
@ -6,10 +6,12 @@
|
|||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
import options
|
||||
import chronos, chronicles
|
||||
import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode]
|
||||
import secure,
|
||||
../../connection,
|
||||
../../stream/lpstream,
|
||||
../../crypto/crypto,
|
||||
../../crypto/ecnist,
|
||||
../../protobuf/minprotobuf,
|
||||
|
@ -60,7 +62,6 @@ type
|
|||
ctxsha1: HMAC[sha1]
|
||||
|
||||
SecureConnection* = ref object of Connection
|
||||
conn*: Connection
|
||||
writerMac: SecureMac
|
||||
readerMac: SecureMac
|
||||
writerCoder: SecureCipher
|
||||
|
@ -176,13 +177,13 @@ proc readMessage*(sconn: SecureConnection): Future[seq[byte]] {.async.} =
|
|||
## Read message from channel secure connection ``sconn``.
|
||||
try:
|
||||
var buf = newSeq[byte](4)
|
||||
await sconn.conn.readExactly(addr buf[0], 4)
|
||||
await sconn.readExactly(addr buf[0], 4)
|
||||
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
|
||||
(int(buf[2]) shl 8) or (int(buf[3]))
|
||||
trace "Recieved message header", header = toHex(buf), length = length
|
||||
if length <= SecioMaxMessageSize:
|
||||
buf.setLen(length)
|
||||
await sconn.conn.readExactly(addr buf[0], length)
|
||||
await sconn.readExactly(addr buf[0], length)
|
||||
trace "Received message body", length = length,
|
||||
buffer = toHex(buf)
|
||||
if sconn.macCheckAndDecode(buf):
|
||||
|
@ -213,21 +214,27 @@ proc writeMessage*(sconn: SecureConnection, message: seq[byte]) {.async.} =
|
|||
msg[3] = byte(length and 0xFF)
|
||||
trace "Writing message", message = toHex(msg)
|
||||
try:
|
||||
await sconn.conn.write(msg)
|
||||
await sconn.write(msg)
|
||||
except AsyncStreamWriteError:
|
||||
trace "Could not write to connection"
|
||||
|
||||
proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
|
||||
proc newSecureConnection*(conn: Connection,
|
||||
hash: string,
|
||||
cipher: string,
|
||||
secrets: Secret,
|
||||
order: int): SecureConnection =
|
||||
order: int,
|
||||
peerId: PeerID): SecureConnection =
|
||||
## Create new secure connection, using specified hash algorithm ``hash``,
|
||||
## cipher algorithm ``cipher``, stretched keys ``secrets`` and order
|
||||
## ``order``.
|
||||
new result
|
||||
|
||||
result.stream = conn
|
||||
result.closeEvent = newAsyncEvent()
|
||||
|
||||
let i0 = if order < 0: 1 else: 0
|
||||
let i1 = if order < 0: 0 else: 1
|
||||
|
||||
result.conn = conn
|
||||
trace "Writer credentials", mackey = toHex(secrets.macOpenArray(i0)),
|
||||
enckey = toHex(secrets.keyOpenArray(i0)),
|
||||
iv = toHex(secrets.ivOpenArray(i0))
|
||||
|
@ -241,6 +248,8 @@ proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
|
|||
result.readerCoder.init(cipher, secrets.keyOpenArray(i1),
|
||||
secrets.ivOpenArray(i1))
|
||||
|
||||
result.peerInfo.peerId = some(peerId)
|
||||
|
||||
proc transactMessage(conn: Connection,
|
||||
msg: seq[byte]): Future[seq[byte]] {.async.} =
|
||||
var buf = newSeq[byte](4)
|
||||
|
@ -281,7 +290,6 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
|
|||
remoteHashes: string
|
||||
remotePeerId: PeerID
|
||||
localPeerId: PeerID
|
||||
ekey: PrivateKey
|
||||
localBytesPubkey = s.localPublicKey.getBytes()
|
||||
|
||||
if randomBytes(localNonce) != SecioNonceSize:
|
||||
|
@ -388,7 +396,8 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
|
|||
|
||||
# Perform Nonce exchange over encrypted channel.
|
||||
|
||||
result = newSecureConnection(conn, hash, cipher, keys, order)
|
||||
result = newSecureConnection(conn, hash, cipher, keys, order, remotePeerId)
|
||||
|
||||
await result.writeMessage(remoteNonce)
|
||||
var res = await result.readMessage()
|
||||
|
||||
|
@ -400,17 +409,21 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
|
|||
trace "Secure handshake succeeded"
|
||||
|
||||
proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} =
|
||||
while not sconn.conn.closed:
|
||||
try:
|
||||
try:
|
||||
while not sconn.closed:
|
||||
let msg = await sconn.readMessage()
|
||||
await stream.pushTo(msg)
|
||||
except CatchableError as exc:
|
||||
trace "exception in secio", exc = exc.msg
|
||||
return
|
||||
finally:
|
||||
trace "ending secio readLoop"
|
||||
if msg.len > 0:
|
||||
await stream.pushTo(msg)
|
||||
|
||||
# tight loop, give a chance for other
|
||||
# stuff to run as well
|
||||
await sleepAsync(1.millis)
|
||||
except CatchableError as exc:
|
||||
trace "exception occured", exc = exc.msg
|
||||
finally:
|
||||
trace "ending secio readLoop", isclosed = sconn.closed()
|
||||
|
||||
proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
|
||||
proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} =
|
||||
var sconn = await s.handshake(conn)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
|
||||
trace "sending encrypted bytes", bytes = data.toHex()
|
||||
|
@ -419,7 +432,13 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
|
|||
var stream = newBufferStream(writeHandler)
|
||||
asyncCheck readLoop(sconn, stream)
|
||||
var secured = newConnection(stream)
|
||||
secured.peerInfo = sconn.conn.peerInfo
|
||||
secured.closeEvent.wait()
|
||||
.addCallback(proc(udata: pointer) =
|
||||
trace "wrapped connection closed, closing upstream"
|
||||
if not sconn.closed:
|
||||
asyncCheck sconn.close()
|
||||
)
|
||||
secured.peerInfo.peerId = sconn.peerInfo.peerId
|
||||
result = secured
|
||||
|
||||
method init(s: Secio) {.gcsafe.} =
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
## those terms.
|
||||
|
||||
## This module implements an asynchronous buffer stream
|
||||
## which emulates physical async IO.
|
||||
## which emulates physical async IO.
|
||||
##
|
||||
## The stream is based on the standard library's `Deque`,
|
||||
## which is itself based on a ring buffer.
|
||||
|
@ -25,12 +25,12 @@
|
|||
## ordered and asynchronous. Reads are queued up in order
|
||||
## and are suspended when not enough data available. This
|
||||
## allows preserving backpressure while maintaining full
|
||||
## asynchrony. Both writting to the internal buffer with
|
||||
## asynchrony. Both writting to the internal buffer with
|
||||
## ``pushTo`` as well as reading with ``read*` methods,
|
||||
## will suspend until either the amount of elements in the
|
||||
## buffer goes below ``maxSize`` or more data becomes available.
|
||||
|
||||
import deques, tables, sequtils, math
|
||||
import deques, math
|
||||
import chronos
|
||||
import ../stream/lpstream
|
||||
|
||||
|
@ -38,33 +38,49 @@ const DefaultBufferSize* = 1024
|
|||
|
||||
type
|
||||
# TODO: figure out how to make this generic to avoid casts
|
||||
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
|
||||
WriteHandler* = proc (data: seq[byte]): Future[void]
|
||||
|
||||
BufferStream* = ref object of LPStream
|
||||
maxSize*: int # buffer's max size in bytes
|
||||
readBuf: Deque[byte] # a deque is based on a ring buffer
|
||||
readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here
|
||||
readReqs: Deque[Future[void]] # use dequeue to fire reads in order
|
||||
dataReadEvent: AsyncEvent
|
||||
writeHandler*: WriteHandler
|
||||
lock: AsyncLock
|
||||
isPiped: bool
|
||||
|
||||
proc requestReadBytes(s: BufferStream): Future[void] =
|
||||
AlreadyPipedError* = object of CatchableError
|
||||
NotWritableError* = object of CatchableError
|
||||
|
||||
proc newAlreadyPipedError*(): ref Exception {.inline.} =
|
||||
result = newException(AlreadyPipedError, "stream already piped")
|
||||
|
||||
proc newNotWritableError*(): ref Exception {.inline.} =
|
||||
result = newException(NotWritableError, "stream is not writable")
|
||||
|
||||
proc requestReadBytes(s: BufferStream): Future[void] =
|
||||
## create a future that will complete when more
|
||||
## data becomes available in the read buffer
|
||||
result = newFuture[void]()
|
||||
s.readReqs.addLast(result)
|
||||
|
||||
proc initBufferStream*(s: BufferStream, handler: WriteHandler, size: int = DefaultBufferSize) =
|
||||
proc initBufferStream*(s: BufferStream,
|
||||
handler: WriteHandler = nil,
|
||||
size: int = DefaultBufferSize) =
|
||||
s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size)
|
||||
s.readBuf = initDeque[byte](s.maxSize)
|
||||
s.readReqs = initDeque[Future[void]]()
|
||||
s.dataReadEvent = newAsyncEvent()
|
||||
s.lock = newAsyncLock()
|
||||
s.writeHandler = handler
|
||||
s.closeEvent = newAsyncEvent()
|
||||
|
||||
proc newBufferStream*(handler: WriteHandler, size: int = DefaultBufferSize): BufferStream =
|
||||
proc newBufferStream*(handler: WriteHandler = nil,
|
||||
size: int = DefaultBufferSize): BufferStream =
|
||||
new result
|
||||
result.initBufferStream(handler, size)
|
||||
|
||||
proc popFirst*(s: BufferStream): byte =
|
||||
proc popFirst*(s: BufferStream): byte =
|
||||
result = s.readBuf.popFirst()
|
||||
s.dataReadEvent.fire()
|
||||
|
||||
|
@ -78,15 +94,24 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) =
|
|||
|
||||
proc len*(s: BufferStream): int = s.readBuf.len
|
||||
|
||||
proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
|
||||
proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
|
||||
## Write bytes to internal read buffer, use this to fill up the
|
||||
## buffer with data.
|
||||
##
|
||||
## This method is async and will wait until all data has been
|
||||
## written to the internal buffer; this is done so that backpressure
|
||||
## is preserved.
|
||||
##
|
||||
|
||||
await s.lock.acquire()
|
||||
var index = 0
|
||||
while true:
|
||||
|
||||
# give readers a chance free up the buffer
|
||||
# it it's full.
|
||||
if s.readBuf.len >= s.maxSize:
|
||||
await sleepAsync(10.millis)
|
||||
|
||||
while index < data.len and s.readBuf.len < s.maxSize:
|
||||
s.readBuf.addLast(data[index])
|
||||
inc(index)
|
||||
|
@ -94,18 +119,20 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
|
|||
# resolve the next queued read request
|
||||
if s.readReqs.len > 0:
|
||||
s.readReqs.popFirst().complete()
|
||||
|
||||
|
||||
if index >= data.len:
|
||||
break
|
||||
|
||||
|
||||
# if we couldn't transfer all the data to the
|
||||
# internal buf wait on a read event
|
||||
await s.dataReadEvent.wait()
|
||||
s.lock.release()
|
||||
|
||||
method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
|
||||
method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async.} =
|
||||
## Read all bytes (n <= 0) or exactly `n` bytes from buffer
|
||||
##
|
||||
## This procedure allocates buffer seq[byte] and return it as result.
|
||||
##
|
||||
var size = if n > 0: n else: s.readBuf.len()
|
||||
var index = 0
|
||||
while index < size:
|
||||
|
@ -116,25 +143,26 @@ method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
|
|||
if index < size:
|
||||
await s.requestReadBytes()
|
||||
|
||||
method readExactly*(s: BufferStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.async, gcsafe.} =
|
||||
method readExactly*(s: BufferStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.async.} =
|
||||
## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store
|
||||
## it to ``pbytes``.
|
||||
##
|
||||
## If EOF is received and ``nbytes`` is not yet read, the procedure
|
||||
## will raise ``LPStreamIncompleteError``.
|
||||
let buff = await s.read(nbytes)
|
||||
##
|
||||
var buff = await s.read(nbytes)
|
||||
if nbytes > buff.len():
|
||||
raise newLPStreamIncompleteError()
|
||||
|
||||
copyMem(pbytes, unsafeAddr buff[0], nbytes)
|
||||
copyMem(pbytes, addr buff[0], nbytes)
|
||||
|
||||
method readLine*(s: BufferStream,
|
||||
limit = 0,
|
||||
sep = "\r\n"):
|
||||
Future[string] {.async, gcsafe.} =
|
||||
sep = "\r\n"):
|
||||
Future[string] {.async.} =
|
||||
## Read one line from read-only stream ``rstream``, where ``"line"`` is a
|
||||
## sequence of bytes ending with ``sep`` (default is ``"\r\n"``).
|
||||
##
|
||||
|
@ -146,6 +174,7 @@ method readLine*(s: BufferStream,
|
|||
##
|
||||
## If ``limit`` more then 0, then result string will be limited to ``limit``
|
||||
## bytes.
|
||||
##
|
||||
result = ""
|
||||
var lim = if limit <= 0: -1 else: limit
|
||||
var state = 0
|
||||
|
@ -170,14 +199,15 @@ method readLine*(s: BufferStream,
|
|||
method readOnce*(s: BufferStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.async, gcsafe.} =
|
||||
Future[int] {.async.} =
|
||||
## Perform one read operation on read-only stream ``rstream``.
|
||||
##
|
||||
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
|
||||
## internal buffer, otherwise it will wait until some bytes will be received.
|
||||
##
|
||||
if s.readBuf.len == 0:
|
||||
await s.requestReadBytes()
|
||||
|
||||
|
||||
var len = if nbytes > s.readBuf.len: s.readBuf.len else: nbytes
|
||||
await s.readExactly(pbytes, len)
|
||||
result = len
|
||||
|
@ -186,7 +216,7 @@ method readUntil*(s: BufferStream,
|
|||
pbytes: pointer,
|
||||
nbytes: int,
|
||||
sep: seq[byte]):
|
||||
Future[int] {.async, gcsafe.} =
|
||||
Future[int] {.async.} =
|
||||
## Read data from the read-only stream ``rstream`` until separator ``sep`` is
|
||||
## found.
|
||||
##
|
||||
|
@ -200,6 +230,7 @@ method readUntil*(s: BufferStream,
|
|||
## will raise ``LPStreamLimitError``.
|
||||
##
|
||||
## Procedure returns actual number of bytes read.
|
||||
##
|
||||
var
|
||||
dest = cast[ptr UncheckedArray[byte]](pbytes)
|
||||
state = 0
|
||||
|
@ -231,22 +262,22 @@ method readUntil*(s: BufferStream,
|
|||
else:
|
||||
s.shrink(datalen)
|
||||
|
||||
method write*(s: BufferStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void]
|
||||
{.gcsafe.} =
|
||||
method write*(s: BufferStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] =
|
||||
## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream
|
||||
## ``rstream``.
|
||||
##
|
||||
## Return number of bytes actually consumed (discarded).
|
||||
##
|
||||
var buf: seq[byte] = newSeq[byte](nbytes)
|
||||
copyMem(addr buf[0], pbytes, nbytes)
|
||||
result = s.writeHandler(buf)
|
||||
if not isNil(s.writeHandler):
|
||||
result = s.writeHandler(buf)
|
||||
|
||||
method write*(s: BufferStream,
|
||||
msg: string,
|
||||
msglen = -1): Future[void]
|
||||
{.gcsafe.} =
|
||||
msglen = -1): Future[void] =
|
||||
## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``.
|
||||
##
|
||||
## String ``sbytes`` must not be zero-length.
|
||||
|
@ -254,14 +285,15 @@ method write*(s: BufferStream,
|
|||
## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream.
|
||||
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
|
||||
## stream.
|
||||
##
|
||||
var buf = ""
|
||||
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
|
||||
result = s.writeHandler(cast[seq[byte]](buf))
|
||||
if not isNil(s.writeHandler):
|
||||
result = s.writeHandler(cast[seq[byte]](buf))
|
||||
|
||||
method write*(s: BufferStream,
|
||||
msg: seq[byte],
|
||||
msglen = -1): Future[void]
|
||||
{.gcsafe.} =
|
||||
msglen = -1): Future[void] =
|
||||
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
|
||||
## stream ``wstream``.
|
||||
##
|
||||
|
@ -270,13 +302,56 @@ method write*(s: BufferStream,
|
|||
## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream.
|
||||
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
|
||||
## stream.
|
||||
##
|
||||
var buf: seq[byte]
|
||||
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
|
||||
result = s.writeHandler(buf)
|
||||
if not isNil(s.writeHandler):
|
||||
result = s.writeHandler(buf)
|
||||
|
||||
method close*(s: BufferStream) {.async, gcsafe.} =
|
||||
proc pipe*(s: BufferStream,
|
||||
target: BufferStream): BufferStream =
|
||||
## pipe the write end of this stream to
|
||||
## be the source of the target stream
|
||||
##
|
||||
## Note that this only works with the LPStream
|
||||
## interface methods `read*` and `write` are
|
||||
## piped.
|
||||
##
|
||||
if s.isPiped:
|
||||
raise newAlreadyPipedError()
|
||||
|
||||
s.isPiped = true
|
||||
let oldHandler = target.writeHandler
|
||||
proc handler(data: seq[byte]) {.async, closure.} =
|
||||
if not isNil(oldHandler):
|
||||
await oldHandler(data)
|
||||
|
||||
# if we're piping to self,
|
||||
# then add the data to the
|
||||
# buffer directly and fire
|
||||
# the read event
|
||||
if s == target:
|
||||
for b in data:
|
||||
s.readBuf.addLast(b)
|
||||
|
||||
# notify main loop of available
|
||||
# data
|
||||
s.dataReadEvent.fire()
|
||||
else:
|
||||
await target.pushTo(data)
|
||||
|
||||
s.writeHandler = handler
|
||||
result = target
|
||||
|
||||
proc `|`*(s: BufferStream, target: BufferStream): BufferStream =
|
||||
## pipe operator to make piping less verbose
|
||||
pipe(s, target)
|
||||
|
||||
method close*(s: BufferStream) {.async.} =
|
||||
## close the stream and clear the buffer
|
||||
for r in s.readReqs:
|
||||
r.cancel()
|
||||
s.dataReadEvent.fire()
|
||||
s.readBuf.clear()
|
||||
s.closed = true
|
||||
s.closeEvent.fire()
|
||||
s.isClosed = true
|
||||
|
|
|
@ -26,40 +26,63 @@ proc newChronosStream*(server: StreamServer,
|
|||
result.client = client
|
||||
result.reader = newAsyncStreamReader(client)
|
||||
result.writer = newAsyncStreamWriter(client)
|
||||
result.closed = false
|
||||
result.closeEvent = newAsyncEvent()
|
||||
|
||||
method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async.} =
|
||||
if s.reader.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
|
||||
try:
|
||||
result = await s.reader.read(n)
|
||||
except AsyncStreamReadError as exc:
|
||||
raise newLPStreamReadError(exc.par)
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method readExactly*(s: ChronosStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.async, gcsafe.} =
|
||||
nbytes: int): Future[void] {.async.} =
|
||||
if s.reader.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
try:
|
||||
await s.reader.readExactly(pbytes, nbytes)
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newLPStreamIncompleteError()
|
||||
except AsyncStreamReadError as exc:
|
||||
raise newLPStreamReadError(exc.par)
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async.} =
|
||||
if s.reader.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async, gcsafe.} =
|
||||
try:
|
||||
result = await s.reader.readLine(limit, sep)
|
||||
except AsyncStreamReadError as exc:
|
||||
raise newLPStreamReadError(exc.par)
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
|
||||
if s.reader.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async, gcsafe.} =
|
||||
try:
|
||||
result = await s.reader.readOnce(pbytes, nbytes)
|
||||
except AsyncStreamReadError as exc:
|
||||
raise newLPStreamReadError(exc.par)
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method readUntil*(s: ChronosStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int,
|
||||
sep: seq[byte]): Future[int] {.async, gcsafe.} =
|
||||
sep: seq[byte]): Future[int] {.async.} =
|
||||
if s.reader.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
try:
|
||||
result = await s.reader.readUntil(pbytes, nbytes, sep)
|
||||
except AsyncStreamIncompleteError:
|
||||
|
@ -68,36 +91,62 @@ method readUntil*(s: ChronosStream,
|
|||
raise newLPStreamLimitError()
|
||||
except LPStreamReadError as exc:
|
||||
raise newLPStreamReadError(exc.par)
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async.} =
|
||||
if s.writer.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async, gcsafe.} =
|
||||
try:
|
||||
await s.writer.write(pbytes, nbytes)
|
||||
except AsyncStreamWriteError as exc:
|
||||
raise newLPStreamWriteError(exc.par)
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newLPStreamIncompleteError()
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method write*(s: ChronosStream, msg: string, msglen = -1) {.async.} =
|
||||
if s.writer.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
method write*(s: ChronosStream, msg: string, msglen = -1) {.async, gcsafe.} =
|
||||
try:
|
||||
await s.writer.write(msg, msglen)
|
||||
except AsyncStreamWriteError as exc:
|
||||
raise newLPStreamWriteError(exc.par)
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newLPStreamIncompleteError()
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} =
|
||||
if s.writer.atEof:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
||||
try:
|
||||
await s.writer.write(msg, msglen)
|
||||
except AsyncStreamWriteError as exc:
|
||||
raise newLPStreamWriteError(exc.par)
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newLPStreamIncompleteError()
|
||||
except AsyncStreamIncorrectError as exc:
|
||||
raise newLPStreamIncorrectError(exc.msg)
|
||||
|
||||
method close*(s: ChronosStream) {.async, gcsafe.} =
|
||||
method closed*(s: ChronosStream): bool {.inline.} =
|
||||
# TODO: we might only need to check for reader's EOF
|
||||
result = s.reader.atEof()
|
||||
|
||||
method close*(s: ChronosStream) {.async.} =
|
||||
if not s.closed:
|
||||
trace "shutting down server", address = $s.client.remoteAddress()
|
||||
await s.writer.finish()
|
||||
await s.writer.closeWait()
|
||||
await s.reader.closeWait()
|
||||
await s.client.closeWait()
|
||||
s.closed = true
|
||||
trace "shutting chronos stream", address = $s.client.remoteAddress()
|
||||
if not s.writer.closed():
|
||||
await s.writer.closeWait()
|
||||
|
||||
if not s.reader.closed():
|
||||
await s.reader.closeWait()
|
||||
|
||||
if not s.client.closed():
|
||||
await s.client.closeWait()
|
||||
|
||||
s.closeEvent.fire()
|
||||
|
|
|
@ -11,7 +11,8 @@ import chronos
|
|||
|
||||
type
|
||||
LPStream* = ref object of RootObj
|
||||
closed*: bool
|
||||
isClosed*: bool
|
||||
closeEvent*: AsyncEvent
|
||||
|
||||
LPStreamError* = object of CatchableError
|
||||
LPStreamIncompleteError* = object of LPStreamError
|
||||
|
@ -47,40 +48,43 @@ proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
|
|||
proc newLPStreamClosedError*(): ref Exception {.inline.} =
|
||||
result = newException(LPStreamClosedError, "Stream closed!")
|
||||
|
||||
method closed*(s: LPStream): bool {.base, inline.} =
|
||||
s.isClosed
|
||||
|
||||
method read*(s: LPStream, n = -1): Future[seq[byte]]
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method readExactly*(s: LPStream, pbytes: pointer, nbytes: int): Future[void]
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string]
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method readOnce*(s: LPStream, pbytes: pointer, nbytes: int): Future[int]
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method readUntil*(s: LPStream,
|
||||
pbytes: pointer, nbytes: int,
|
||||
sep: seq[byte]): Future[int]
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method write*(s: LPStream, pbytes: pointer, nbytes: int)
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method write*(s: LPStream, msg: string, msglen = -1)
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method write*(s: LPStream, msg: seq[byte], msglen = -1)
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method close*(s: LPStream)
|
||||
{.base, async, gcsafe.} =
|
||||
{.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
|
|
@ -11,13 +11,11 @@ import tables, sequtils, options, strformat
|
|||
import chronos, chronicles
|
||||
import connection,
|
||||
transports/transport,
|
||||
stream/lpstream,
|
||||
multistream,
|
||||
protocols/protocol,
|
||||
protocols/secure/secure,
|
||||
protocols/secure/plaintext, # for plain text
|
||||
peerinfo,
|
||||
multiaddress,
|
||||
protocols/identify,
|
||||
protocols/pubsub/pubsub,
|
||||
muxers/muxer,
|
||||
|
@ -26,6 +24,12 @@ import connection,
|
|||
logScope:
|
||||
topic = "Switch"
|
||||
|
||||
#TODO: General note - use a finite state machine to manage the different
|
||||
# steps of connections establishing and upgrading. This makes everything
|
||||
# more robust and less prone to ordering attacks - i.e. muxing can come if
|
||||
# and only if the channel has been secured (i.e. if a secure manager has been
|
||||
# previously provided)
|
||||
|
||||
type
|
||||
NoPubSubException = object of CatchableError
|
||||
|
||||
|
@ -48,7 +52,6 @@ proc newNoPubSubException(): ref Exception {.inline.} =
|
|||
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
|
||||
## secure the incoming connection
|
||||
|
||||
# plaintext for now, doesn't do anything
|
||||
let managers = toSeq(s.secureManagers.keys)
|
||||
if managers.len == 0:
|
||||
raise newException(CatchableError, "No secure managers registered!")
|
||||
|
@ -62,20 +65,21 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
|
|||
proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
|
||||
## identify the connection
|
||||
|
||||
result = conn.peerInfo
|
||||
try:
|
||||
if (await s.ms.select(conn, s.identity.codec)):
|
||||
let info = await s.identity.identify(conn, conn.peerInfo)
|
||||
|
||||
if info.pubKey.isSome:
|
||||
result.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all
|
||||
|
||||
trace "identify: identified remote peer", peer = result.id
|
||||
|
||||
if info.addrs.len > 0:
|
||||
result.addrs = info.addrs
|
||||
|
||||
|
||||
if info.protos.len > 0:
|
||||
result.protocols = info.protos
|
||||
|
||||
trace "identify: identified remote peer ", peer = result.peerId.get().pretty
|
||||
except IdentityInvalidMsgError as exc:
|
||||
error "identify: invalid message", msg = exc.msg
|
||||
except IdentityNoMatchError as exc:
|
||||
|
@ -100,22 +104,23 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
|
|||
muxer.streamHandler = s.streamHandler
|
||||
|
||||
# new stream for identify
|
||||
let stream = await muxer.newStream()
|
||||
var stream = await muxer.newStream()
|
||||
let handlerFut = muxer.handle()
|
||||
|
||||
# add muxer handler cleanup proc
|
||||
handlerFut.addCallback(
|
||||
proc(udata: pointer = nil) {.gcsafe.} =
|
||||
trace "mux: Muxer handler completed for peer ",
|
||||
peer = conn.peerInfo.peerId.get().pretty
|
||||
trace "muxer handler completed for peer",
|
||||
peer = conn.peerInfo.id
|
||||
)
|
||||
|
||||
# do identify first, so that we have a
|
||||
# PeerInfo in case we didn't before
|
||||
conn.peerInfo = await s.identify(stream)
|
||||
await stream.close() # close idenity stream
|
||||
|
||||
trace "connection's peerInfo", peerInfo = conn.peerInfo.peerId
|
||||
|
||||
await stream.close() # close identify stream
|
||||
|
||||
trace "connection's peerInfo", peerInfo = conn.peerInfo
|
||||
|
||||
# store it in muxed connections if we have a peer for it
|
||||
# TODO: We should make sure that this are cleaned up properly
|
||||
|
@ -123,43 +128,42 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
|
|||
# happen once secio is in place, but still something to keep
|
||||
# in mind
|
||||
if conn.peerInfo.peerId.isSome:
|
||||
trace "adding muxer for peer", peer = conn.peerInfo.peerId.get().pretty
|
||||
s.muxed[conn.peerInfo.peerId.get().pretty] = muxer
|
||||
trace "adding muxer for peer", peer = conn.peerInfo.id
|
||||
s.muxed[conn.peerInfo.id] = muxer
|
||||
|
||||
proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||
if conn.peerInfo.peerId.isSome:
|
||||
let id = conn.peerInfo.peerId.get().pretty
|
||||
if s.muxed.contains(id):
|
||||
await s.muxed[id].close
|
||||
|
||||
if s.connections.contains(id):
|
||||
let id = conn.peerInfo.id
|
||||
trace "cleaning up connection for peer", peerId = id
|
||||
if id in s.muxed:
|
||||
await s.muxed[id].close()
|
||||
s.muxed.del(id)
|
||||
|
||||
if id in s.connections:
|
||||
await s.connections[id].close()
|
||||
s.connections.del(id)
|
||||
|
||||
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} =
|
||||
# if there is a muxer for the connection
|
||||
# use it instead to create a muxed stream
|
||||
if s.muxed.contains(peerInfo.peerId.get().pretty):
|
||||
trace "connection is muxed, retriving muxer and setting up a stream"
|
||||
let muxer = s.muxed[peerInfo.peerId.get().pretty]
|
||||
if peerInfo.id in s.muxed:
|
||||
trace "connection is muxed, setting up a stream"
|
||||
let muxer = s.muxed[peerInfo.id]
|
||||
let conn = await muxer.newStream()
|
||||
result = some(conn)
|
||||
|
||||
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
|
||||
trace "handling connection", conn = conn
|
||||
result = conn
|
||||
## perform upgrade flow
|
||||
if result.peerInfo.peerId.isSome:
|
||||
let id = result.peerInfo.peerId.get().pretty
|
||||
if s.connections.contains(id):
|
||||
# if we already have a connection for this peer,
|
||||
# close the incoming connection and return the
|
||||
# existing one
|
||||
await result.close()
|
||||
return s.connections[id]
|
||||
s.connections[id] = result
|
||||
|
||||
result = await s.secure(conn) # secure the connection
|
||||
# don't mux/secure twise
|
||||
if conn.peerInfo.peerId.isSome and
|
||||
conn.peerInfo.id in s.muxed:
|
||||
return
|
||||
|
||||
result = await s.secure(result) # secure the connection
|
||||
await s.mux(result) # mux it if possible
|
||||
s.connections[conn.peerInfo.id] = result
|
||||
|
||||
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||
trace "upgrading incoming connection"
|
||||
|
@ -192,42 +196,57 @@ proc dial*(s: Switch,
|
|||
peer: PeerInfo,
|
||||
proto: string = ""):
|
||||
Future[Connection] {.async.} =
|
||||
trace "dialing peer", peer = peer.peerId.get().pretty
|
||||
let id = peer.id
|
||||
trace "dialing peer", peer = id
|
||||
for t in s.transports: # for each transport
|
||||
for a in peer.addrs: # for each address
|
||||
if t.handles(a): # check if it can dial it
|
||||
result = await t.dial(a)
|
||||
# make sure to assign the peer to the connection
|
||||
result.peerInfo = peer
|
||||
if id notin s.connections:
|
||||
trace "dialing address", address = $a
|
||||
result = await t.dial(a)
|
||||
# make sure to assign the peer to the connection
|
||||
result.peerInfo = peer
|
||||
result = await s.upgradeOutgoing(result)
|
||||
result.closeEvent.wait().addCallback(
|
||||
proc(udata: pointer) =
|
||||
asyncCheck s.cleanupConn(result)
|
||||
)
|
||||
|
||||
let stream = await s.getMuxedStream(peer)
|
||||
if stream.isSome:
|
||||
trace "connection is muxed, return muxed stream"
|
||||
result = stream.get()
|
||||
if proto.len > 0 and not result.closed:
|
||||
let stream = await s.getMuxedStream(peer)
|
||||
if stream.isSome:
|
||||
trace "connection is muxed, return muxed stream"
|
||||
result = stream.get()
|
||||
trace "attempting to select remote", proto = proto
|
||||
|
||||
trace "dial: attempting to select remote ", proto = proto
|
||||
if not (await s.ms.select(result, proto)):
|
||||
error "dial: Unable to select protocol: ", proto = proto
|
||||
raise newException(CatchableError,
|
||||
&"Unable to select protocol: {proto}")
|
||||
if not (await s.ms.select(result, proto)):
|
||||
error "unable to select protocol: ", proto = proto
|
||||
raise newException(CatchableError,
|
||||
&"unable to select protocol: {proto}")
|
||||
|
||||
break # don't dial more than one addr on the same transport
|
||||
|
||||
proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
|
||||
if isNil(proto.handler):
|
||||
raise newException(CatchableError,
|
||||
raise newException(CatchableError,
|
||||
"Protocol has to define a handle method or proc")
|
||||
|
||||
if proto.codec.len == 0:
|
||||
raise newException(CatchableError,
|
||||
raise newException(CatchableError,
|
||||
"Protocol has to define a codec string")
|
||||
|
||||
s.ms.addHandler(proto.codec, proto)
|
||||
|
||||
proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
|
||||
trace "starting switch"
|
||||
|
||||
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
|
||||
try:
|
||||
await s.upgradeIncoming(conn) # perform upgrade on incoming connection
|
||||
except CatchableError as exc:
|
||||
trace "exception occured", exc = exc.msg
|
||||
finally:
|
||||
await conn.close()
|
||||
await s.cleanupConn(conn)
|
||||
|
||||
var startFuts: seq[Future[void]]
|
||||
|
@ -237,10 +256,13 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
|
|||
var server = await t.listen(a, handle)
|
||||
s.peerInfo.addrs[i] = t.ma # update peer's address
|
||||
startFuts.add(server)
|
||||
|
||||
result = startFuts # listen for incoming connections
|
||||
|
||||
proc stop*(s: Switch) {.async.} =
|
||||
await allFutures(toSeq(s.connections.values).mapIt(it.close()))
|
||||
trace "stopping switch"
|
||||
|
||||
await allFutures(toSeq(s.connections.values).mapIt(s.cleanupConn(it)))
|
||||
await allFutures(s.transports.mapIt(it.close()))
|
||||
|
||||
proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
|
||||
|
@ -253,14 +275,14 @@ proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] {
|
|||
## subscribe to a pubsub topic
|
||||
if s.pubSub.isNone:
|
||||
raise newNoPubSubException()
|
||||
|
||||
|
||||
result = s.pubSub.get().subscribe(topic, handler)
|
||||
|
||||
proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] {.gcsafe.} =
|
||||
## unsubscribe from topics
|
||||
if s.pubSub.isNone:
|
||||
raise newNoPubSubException()
|
||||
|
||||
|
||||
result = s.pubSub.get().unsubscribe(topics)
|
||||
|
||||
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[void] {.gcsafe.} =
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import chronos, chronicles
|
||||
import chronos, chronicles, sequtils
|
||||
import transport,
|
||||
../wire,
|
||||
../connection,
|
||||
|
@ -78,5 +78,5 @@ method dial*(t: TcpTransport,
|
|||
result = await t.connHandler(t.server, client, true)
|
||||
|
||||
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
|
||||
## TODO: implement logic to properly discriminat TCP multiaddrs
|
||||
true
|
||||
if procCall Transport(t).handles(address):
|
||||
result = address.protocols.filterIt( it == multiCodec("tcp") ).len > 0
|
||||
|
|
|
@ -9,8 +9,7 @@
|
|||
|
||||
import sequtils
|
||||
import chronos, chronicles
|
||||
import ../peerinfo,
|
||||
../connection,
|
||||
import ../connection,
|
||||
../multiaddress,
|
||||
../multicodec
|
||||
|
||||
|
@ -62,9 +61,10 @@ method upgrade*(t: Transport) {.base, async, gcsafe.} =
|
|||
|
||||
method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} =
|
||||
## check if transport supportes the multiaddress
|
||||
# TODO: this should implement generic logic that would use the multicodec
|
||||
# declared in the multicodec field and set by each individual transport
|
||||
discard
|
||||
|
||||
# by default we skip circuit addresses to avoid
|
||||
# having to repeat the check in every transport
|
||||
address.protocols.filterIt( it == multiCodec("p2p-circuit") ).len == 0
|
||||
|
||||
method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} =
|
||||
## get the local address of the transport in case started with 0.0.0.0:0
|
||||
|
|
|
@ -53,7 +53,17 @@ proc initVBuffer*(): VBuffer =
|
|||
## Initialize empty VBuffer.
|
||||
result.buffer = newSeqOfCap[byte](128)
|
||||
|
||||
proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
|
||||
proc writePBVarint*(vb: var VBuffer, value: PBSomeUVarint) =
|
||||
## Write ``value`` as variable unsigned integer.
|
||||
var length = 0
|
||||
var v = value and cast[type(value)](0xFFFF_FFFF_FFFF_FFFF)
|
||||
vb.buffer.setLen(len(vb.buffer) + vsizeof(v))
|
||||
let res = PB.putUVarint(toOpenArray(vb.buffer, vb.offset, len(vb.buffer) - 1),
|
||||
length, v)
|
||||
doAssert(res == VarintStatus.Success)
|
||||
vb.offset += length
|
||||
|
||||
proc writeLPVarint*(vb: var VBuffer, value: LPSomeUVarint) =
|
||||
## Write ``value`` as variable unsigned integer.
|
||||
var length = 0
|
||||
# LibP2P varint supports only 63 bits.
|
||||
|
@ -64,6 +74,9 @@ proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
|
|||
doAssert(res == VarintStatus.Success)
|
||||
vb.offset += length
|
||||
|
||||
proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
|
||||
writeLPVarint(vb, value)
|
||||
|
||||
proc writeSeq*[T: byte|char](vb: var VBuffer, value: openarray[T]) =
|
||||
## Write array ``value`` to buffer ``vb``, value will be prefixed with
|
||||
## varint length of the array.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import unittest, deques, sequtils, strformat
|
||||
import unittest, strformat
|
||||
import chronos
|
||||
import ../libp2p/stream/bufferstream
|
||||
|
||||
|
@ -220,7 +220,6 @@ suite "BufferStream":
|
|||
|
||||
test "reads should happen in order":
|
||||
proc testWritePtr(): Future[bool] {.async.} =
|
||||
var count = 1
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let buff = newBufferStream(writeHandler, 10)
|
||||
check buff.len == 0
|
||||
|
@ -245,3 +244,199 @@ suite "BufferStream":
|
|||
|
||||
check:
|
||||
waitFor(testWritePtr()) == true
|
||||
|
||||
test "pipe two streams without the `pipe` or `|` helpers":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
proc writeHandler1(data: seq[byte]) {.async, gcsafe.}
|
||||
proc writeHandler2(data: seq[byte]) {.async, gcsafe.}
|
||||
|
||||
var buf1 = newBufferStream(writeHandler1)
|
||||
var buf2 = newBufferStream(writeHandler2)
|
||||
|
||||
proc writeHandler1(data: seq[byte]) {.async, gcsafe.} =
|
||||
var msg = cast[string](data)
|
||||
check msg == "Hello!"
|
||||
await buf2.pushTo(data)
|
||||
|
||||
proc writeHandler2(data: seq[byte]) {.async, gcsafe.} =
|
||||
var msg = cast[string](data)
|
||||
check msg == "Hello!"
|
||||
await buf1.pushTo(data)
|
||||
|
||||
var res1: seq[byte] = newSeq[byte](7)
|
||||
var readFut1 = buf1.readExactly(addr res1[0], 7)
|
||||
|
||||
var res2: seq[byte] = newSeq[byte](7)
|
||||
var readFut2 = buf2.readExactly(addr res2[0], 7)
|
||||
|
||||
await buf1.pushTo(cast[seq[byte]]("Hello2!"))
|
||||
await buf2.pushTo(cast[seq[byte]]("Hello1!"))
|
||||
|
||||
await allFutures(readFut1, readFut2)
|
||||
|
||||
check:
|
||||
res1 == cast[seq[byte]]("Hello2!")
|
||||
res2 == cast[seq[byte]]("Hello1!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
test "pipe A -> B":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
var buf1 = newBufferStream()
|
||||
var buf2 = buf1.pipe(newBufferStream())
|
||||
|
||||
var res1: seq[byte] = newSeq[byte](7)
|
||||
var readFut = buf2.readExactly(addr res1[0], 7)
|
||||
await buf1.write(cast[seq[byte]]("Hello1!"))
|
||||
await readFut
|
||||
|
||||
check:
|
||||
res1 == cast[seq[byte]]("Hello1!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
test "pipe A -> B and B -> A":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
var buf1 = newBufferStream()
|
||||
var buf2 = newBufferStream()
|
||||
|
||||
buf1 = buf1.pipe(buf2).pipe(buf1)
|
||||
|
||||
var res1: seq[byte] = newSeq[byte](7)
|
||||
var readFut1 = buf1.readExactly(addr res1[0], 7)
|
||||
|
||||
var res2: seq[byte] = newSeq[byte](7)
|
||||
var readFut2 = buf2.readExactly(addr res2[0], 7)
|
||||
|
||||
await buf1.write(cast[seq[byte]]("Hello1!"))
|
||||
await buf2.write(cast[seq[byte]]("Hello2!"))
|
||||
await allFutures(readFut1, readFut2)
|
||||
|
||||
check:
|
||||
res1 == cast[seq[byte]]("Hello2!")
|
||||
res2 == cast[seq[byte]]("Hello1!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
test "pipe A -> A (echo)":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
var buf1 = newBufferStream()
|
||||
|
||||
buf1 = buf1.pipe(buf1)
|
||||
|
||||
proc reader(): Future[seq[byte]] = buf1.read(6)
|
||||
proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))
|
||||
|
||||
var writerFut = writer()
|
||||
var readerFut = reader()
|
||||
|
||||
await writerFut
|
||||
check:
|
||||
(await readerFut) == cast[seq[byte]]("Hello!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
test "pipe with `|` operator - A -> B":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
var buf1 = newBufferStream()
|
||||
var buf2 = buf1 | newBufferStream()
|
||||
|
||||
var res1: seq[byte] = newSeq[byte](7)
|
||||
var readFut = buf2.readExactly(addr res1[0], 7)
|
||||
await buf1.write(cast[seq[byte]]("Hello1!"))
|
||||
await readFut
|
||||
|
||||
check:
|
||||
res1 == cast[seq[byte]]("Hello1!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
test "pipe with `|` operator - A -> B and B -> A":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
var buf1 = newBufferStream()
|
||||
var buf2 = newBufferStream()
|
||||
|
||||
buf1 = buf1 | buf2 | buf1
|
||||
|
||||
var res1: seq[byte] = newSeq[byte](7)
|
||||
var readFut1 = buf1.readExactly(addr res1[0], 7)
|
||||
|
||||
var res2: seq[byte] = newSeq[byte](7)
|
||||
var readFut2 = buf2.readExactly(addr res2[0], 7)
|
||||
|
||||
await buf1.write(cast[seq[byte]]("Hello1!"))
|
||||
await buf2.write(cast[seq[byte]]("Hello2!"))
|
||||
await allFutures(readFut1, readFut2)
|
||||
|
||||
check:
|
||||
res1 == cast[seq[byte]]("Hello2!")
|
||||
res2 == cast[seq[byte]]("Hello1!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
test "pipe with `|` operator - A -> A (echo)":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
var buf1 = newBufferStream()
|
||||
|
||||
buf1 = buf1 | buf1
|
||||
|
||||
proc reader(): Future[seq[byte]] = buf1.read(6)
|
||||
proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))
|
||||
|
||||
var writerFut = writer()
|
||||
var readerFut = reader()
|
||||
|
||||
await writerFut
|
||||
check:
|
||||
(await readerFut) == cast[seq[byte]]("Hello!")
|
||||
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
||||
# TODO: Need to implement deadlock prevention when
|
||||
# piping to self
|
||||
test "pipe deadlock":
|
||||
proc pipeTest(): Future[bool] {.async.} =
|
||||
|
||||
var buf1 = newBufferStream(size = 5)
|
||||
|
||||
buf1 = buf1 | buf1
|
||||
|
||||
var count = 30000
|
||||
proc reader() {.async.} =
|
||||
while count > 0:
|
||||
discard await buf1.read(7)
|
||||
|
||||
proc writer() {.async.} =
|
||||
while count > 0:
|
||||
await buf1.write(cast[seq[byte]]("Hello2!"))
|
||||
count.dec
|
||||
|
||||
var writerFut = writer()
|
||||
var readerFut = reader()
|
||||
|
||||
await allFutures(readerFut, writerFut)
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(pipeTest()) == true
|
||||
|
|
|
@ -274,25 +274,15 @@ suite "Mplex":
|
|||
expect LPStreamClosedError:
|
||||
waitFor(testClosedForWrite())
|
||||
|
||||
test "half closed - channel should close for read":
|
||||
proc testClosedForRead(): Future[void] {.async.} =
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.closedByRemote()
|
||||
asyncDiscard chann.read()
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testClosedForRead())
|
||||
|
||||
test "half closed - channel should close for read after eof":
|
||||
test "half closed - channel should close for read by remote":
|
||||
proc testClosedForRead(): Future[void] {.async.} =
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
|
||||
await chann.pushTo(cast[seq[byte]]("Hello!"))
|
||||
await chann.close()
|
||||
let msg = await chann.read()
|
||||
asyncDiscard chann.read()
|
||||
await chann.closedByRemote()
|
||||
discard await chann.read() # this should work, since there is data in the buffer
|
||||
discard await chann.read() # this should throw
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testClosedForRead())
|
||||
|
@ -312,7 +302,7 @@ suite "Mplex":
|
|||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
|
||||
await chann.reset()
|
||||
asyncDiscard chann.read()
|
||||
await chann.write(cast[seq[byte]]("Hello!"))
|
||||
|
||||
expect LPStreamClosedError:
|
||||
waitFor(testResetWrite())
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import unittest, strutils, sequtils, sugar, strformat, options
|
||||
import unittest, strutils, sequtils, strformat, options
|
||||
import chronos
|
||||
import ../libp2p/connection,
|
||||
../libp2p/multistream,
|
||||
|
@ -51,7 +51,8 @@ method write*(s: TestSelectStream, msg: seq[byte], msglen = -1)
|
|||
method write*(s: TestSelectStream, msg: string, msglen = -1)
|
||||
{.async, gcsafe.} = discard
|
||||
|
||||
method close(s: TestSelectStream) {.async, gcsafe.} = s.closed = true
|
||||
method close(s: TestSelectStream) {.async, gcsafe.} =
|
||||
s.isClosed = true
|
||||
|
||||
proc newTestSelectStream(): TestSelectStream =
|
||||
new result
|
||||
|
@ -97,7 +98,8 @@ method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
|||
method write*(s: TestLsStream, msg: string, msglen = -1)
|
||||
{.async, gcsafe.} = discard
|
||||
|
||||
method close(s: TestLsStream) {.async, gcsafe.} = s.closed = true
|
||||
method close(s: TestLsStream) {.async, gcsafe.} =
|
||||
s.isClosed = true
|
||||
|
||||
proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} =
|
||||
new result
|
||||
|
@ -143,7 +145,8 @@ method write*(s: TestNaStream, msg: string, msglen = -1) {.async, gcsafe.} =
|
|||
if s.step == 4:
|
||||
await s.na(msg)
|
||||
|
||||
method close(s: TestNaStream) {.async, gcsafe.} = s.closed = true
|
||||
method close(s: TestNaStream) {.async, gcsafe.} =
|
||||
s.isClosed = true
|
||||
|
||||
proc newTestNaStream(na: NaHandler): TestNaStream =
|
||||
new result
|
||||
|
|
|
@ -2,5 +2,11 @@ import unittest
|
|||
import testvarint, testbase32, testbase58, testbase64
|
||||
import testrsa, testecnist, tested25519, testsecp256k1, testcrypto
|
||||
import testmultibase, testmultihash, testmultiaddress, testcid, testpeer
|
||||
import testtransport, testmultistream, testbufferstream,
|
||||
testmplex, testidentify, testswitch, testpubsub
|
||||
|
||||
import testtransport,
|
||||
testmultistream,
|
||||
testbufferstream,
|
||||
testidentify,
|
||||
testswitch,
|
||||
testpubsub,
|
||||
testmplex
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import unittest, tables, options
|
||||
import chronos, chronicles
|
||||
import chronos
|
||||
import ../libp2p/[switch,
|
||||
multistream,
|
||||
protocols/identify,
|
||||
|
@ -36,7 +36,7 @@ method init(p: TestProto) {.gcsafe.} =
|
|||
|
||||
suite "Switch":
|
||||
test "e2e use switch":
|
||||
proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
|
||||
proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) {.gcsafe.}=
|
||||
let seckey = PrivateKey.random(RSA)
|
||||
var peerInfo: PeerInfo
|
||||
peerInfo.peerId = some(PeerID.init(seckey))
|
||||
|
@ -50,7 +50,11 @@ suite "Switch":
|
|||
let transports = @[Transport(newTransport(TcpTransport))]
|
||||
let muxers = [(MplexCodec, mplexProvider)].toTable()
|
||||
let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable()
|
||||
let switch = newSwitch(peerInfo, transports, identify, muxers, secureManagers)
|
||||
let switch = newSwitch(peerInfo,
|
||||
transports,
|
||||
identify,
|
||||
muxers,
|
||||
secureManagers)
|
||||
result = (switch, peerInfo)
|
||||
|
||||
proc testSwitch(): Future[bool] {.async, gcsafe.} =
|
||||
|
|
Loading…
Reference in New Issue