Feat/conn cleanup (#41)

Backporting proper connection cleanup from #36 to align with latest chronos changes.

* add close event

* use proper varint encoding

* add proper channel cleanup in mplex

* add connection cleanup in secio

* tidy up

* add dollar operator

* fix tests

* don't close connections prematurely

* handle closing streams properly

* misc

* implement address filtering logic

* adding pipe tests

* don't use gcsafe if not needed

* misc

* proper connection cleanup and stream muxing

* re-enable pubsub tests
This commit is contained in:
Dmitriy Ryajov 2019-12-03 22:44:54 -06:00 committed by GitHub
parent 1df16bdbce
commit 903e79ede1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 765 additions and 294 deletions

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import chronos, options, chronicles import chronos, chronicles
import peerinfo, import peerinfo,
multiaddress, multiaddress,
stream/lpstream, stream/lpstream,
@ -26,15 +26,28 @@ type
InvalidVarintException = object of LPStreamError InvalidVarintException = object of LPStreamError
proc newInvalidVarintException*(): ref InvalidVarintException = proc newInvalidVarintException*(): ref InvalidVarintException =
result = newException(InvalidVarintException, "unable to prase varint") newException(InvalidVarintException, "unable to prase varint")
proc newConnection*(stream: LPStream): Connection = proc newConnection*(stream: LPStream): Connection =
## create a new Connection for the specified async reader/writer ## create a new Connection for the specified async reader/writer
new result new result
result.stream = stream result.stream = stream
result.closeEvent = newAsyncEvent()
# bind stream's close event to connection's close
# to ensure correct close propagation
let this = result
if not isNil(result.stream.closeEvent):
result.stream.closeEvent.wait().
addCallback(
proc (udata: pointer) =
if not this.closed:
trace "closing this connection because wrapped stream closed"
asyncCheck this.close()
)
method read*(s: Connection, n = -1): Future[seq[byte]] {.gcsafe.} = method read*(s: Connection, n = -1): Future[seq[byte]] {.gcsafe.} =
result = s.stream.read(n) s.stream.read(n)
method readExactly*(s: Connection, method readExactly*(s: Connection,
pbytes: pointer, pbytes: pointer,
@ -79,9 +92,20 @@ method write*(s: Connection,
Future[void] {.gcsafe.} = Future[void] {.gcsafe.} =
s.stream.write(msg, msglen) s.stream.write(msg, msglen)
method closed*(s: Connection): bool =
if isNil(s.stream):
return false
result = s.stream.closed
method close*(s: Connection) {.async, gcsafe.} = method close*(s: Connection) {.async, gcsafe.} =
await s.stream.close() trace "closing connection"
s.closed = true if not s.closed:
if not isNil(s.stream) and not s.stream.closed:
await s.stream.close()
s.closeEvent.fire()
s.isClosed = true
trace "connection closed", closed = s.closed
proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} = proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
## read lenght prefixed msg ## read lenght prefixed msg
@ -100,21 +124,23 @@ proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
raise newInvalidVarintException() raise newInvalidVarintException()
result.setLen(size) result.setLen(size)
if size > 0.uint: if size > 0.uint:
trace "reading exact bytes from stream", size = size
await s.readExactly(addr result[0], int(size)) await s.readExactly(addr result[0], int(size))
except LPStreamIncompleteError, LPStreamReadError: except LPStreamIncompleteError as exc:
trace "remote connection closed", exc = getCurrentExceptionMsg() trace "remote connection ended unexpectedly", exc = exc.msg
except LPStreamReadError as exc:
trace "couldn't read from stream", exc = exc.msg
proc writeLp*(s: Connection, msg: string | seq[byte]): Future[void] {.gcsafe.} = proc writeLp*(s: Connection, msg: string | seq[byte]): Future[void] {.gcsafe.} =
## write lenght prefixed ## write lenght prefixed
var buf = initVBuffer() var buf = initVBuffer()
buf.writeSeq(msg) buf.writeSeq(msg)
buf.finish() buf.finish()
result = s.write(buf.buffer) s.write(buf.buffer)
method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} = method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} =
## get resolved multiaddresses for the connection ## get resolved multiaddresses for the connection
result = c.observedAddrs result = c.observedAddrs
proc `$`*(conn: Connection): string = proc `$`*(conn: Connection): string =
if conn.peerInfo.peerId.isSome: result = $(conn.peerInfo)
result = $(conn.peerInfo.peerId.get())

View File

@ -855,7 +855,7 @@ proc connect*(api: DaemonAPI, peer: PeerID,
timeout)) timeout))
pb.withMessage() do: pb.withMessage() do:
discard discard
finally: except:
await api.closeConnection(transp) await api.closeConnection(transp)
proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} = proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} =

View File

@ -7,12 +7,12 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import sequtils, strutils, strformat import strutils
import chronos, chronicles import chronos, chronicles
import connection, import connection,
varint,
vbuffer, vbuffer,
protocols/protocol protocols/protocol,
stream/lpstream
logScope: logScope:
topic = "Multistream" topic = "Multistream"
@ -56,16 +56,16 @@ proc select*(m: MultisteamSelect,
trace "selecting proto", proto = proto trace "selecting proto", proto = proto
await conn.writeLp((proto[0] & "\n")) # select proto await conn.writeLp((proto[0] & "\n")) # select proto
result = cast[string](await conn.readLp()) # read ms header result = cast[string]((await conn.readLp())) # read ms header
result.removeSuffix("\n") result.removeSuffix("\n")
if result != Codec: if result != Codec:
trace "handshake failed", codec = result trace "handshake failed", codec = result.toHex()
return "" return ""
if proto.len() == 0: # no protocols, must be a handshake call if proto.len() == 0: # no protocols, must be a handshake call
return return
result = cast[string](await conn.readLp()) # read the first proto result = cast[string]((await conn.readLp())) # read the first proto
trace "reading first requested proto" trace "reading first requested proto"
result.removeSuffix("\n") result.removeSuffix("\n")
if result == proto[0]: if result == proto[0]:
@ -76,7 +76,7 @@ proc select*(m: MultisteamSelect,
trace "selecting one of several protos" trace "selecting one of several protos"
for p in proto[1..<proto.len()]: for p in proto[1..<proto.len()]:
await conn.writeLp((p & "\n")) # select proto await conn.writeLp((p & "\n")) # select proto
result = cast[string](await conn.readLp()) # read the first proto result = cast[string]((await conn.readLp())) # read the first proto
result.removeSuffix("\n") result.removeSuffix("\n")
if result == p: if result == p:
trace "selected protocol", protocol = result trace "selected protocol", protocol = result
@ -102,7 +102,7 @@ proc list*(m: MultisteamSelect,
await conn.write(m.ls) # send ls await conn.write(m.ls) # send ls
var list = newSeq[string]() var list = newSeq[string]()
let ms = cast[string](await conn.readLp()) let ms = cast[string]((await conn.readLp()))
for s in ms.split("\n"): for s in ms.split("\n"):
if s.len() > 0: if s.len() > 0:
list.add(s) list.add(s)
@ -111,8 +111,10 @@ proc list*(m: MultisteamSelect,
proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} = proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
trace "handle: starting multistream handling" trace "handle: starting multistream handling"
while not conn.closed: try:
var ms = cast[string](await conn.readLp()) while not conn.closed:
await sleepAsync(1.millis)
var ms = cast[string]((await conn.readLp()))
ms.removeSuffix("\n") ms.removeSuffix("\n")
trace "handle: got request for ", ms trace "handle: got request for ", ms
@ -142,11 +144,15 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
try: try:
await h.protocol.handler(conn, ms) await h.protocol.handler(conn, ms)
return return
except Exception as exc: except CatchableError as exc:
warn "exception while handling ", msg = exc.msg warn "exception while handling", msg = exc.msg
return return
warn "no handlers for ", protocol = ms warn "no handlers for ", protocol = ms
await conn.write(m.na) await conn.write(m.na)
except CatchableError as exc:
trace "exception occured", exc = exc.msg
finally:
trace "leaving multistream loop"
proc addHandler*[T: LPProtocol](m: MultisteamSelect, proc addHandler*[T: LPProtocol](m: MultisteamSelect,
codec: string, codec: string,

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import chronos, options, sequtils, strformat import chronos, options
import nimcrypto/utils, chronicles import nimcrypto/utils, chronicles
import types, import types,
../../connection, ../../connection,
@ -29,31 +29,33 @@ proc readMplexVarint(conn: Connection): Future[Option[uint]] {.async, gcsafe.} =
varint: uint varint: uint
length: int length: int
res: VarintStatus res: VarintStatus
var buffer = newSeq[byte](10) buffer = newSeq[byte](10)
result = none(uint) result = none(uint)
try: try:
for i in 0..<len(buffer): for i in 0..<len(buffer):
await conn.readExactly(addr buffer[i], 1) if not conn.closed:
res = LP.getUVarint(buffer.toOpenArray(0, i), length, varint) await conn.readExactly(addr buffer[i], 1)
if res == VarintStatus.Success: res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint)
return some(varint) if res == VarintStatus.Success:
return some(varint)
if res != VarintStatus.Success: if res != VarintStatus.Success:
raise newInvalidVarintException() raise newInvalidVarintException()
except LPStreamIncompleteError: except LPStreamIncompleteError as exc:
trace "unable to read varint", exc = getCurrentExceptionMsg() trace "unable to read varint", exc = exc.msg
proc readMsg*(conn: Connection): Future[Option[Msg]] {.async, gcsafe.} = proc readMsg*(conn: Connection): Future[Option[Msg]] {.async, gcsafe.} =
let headerVarint = await conn.readMplexVarint() let headerVarint = await conn.readMplexVarint()
if headerVarint.isNone: if headerVarint.isNone:
return return
trace "readMsg: read header varint ", varint = headerVarint trace "read header varint", varint = headerVarint
let dataLenVarint = await conn.readMplexVarint() let dataLenVarint = await conn.readMplexVarint()
var data: seq[byte] var data: seq[byte]
if dataLenVarint.isSome and dataLenVarint.get() > 0.uint: if dataLenVarint.isSome and dataLenVarint.get() > 0.uint:
trace "readMsg: read size varint ", varint = dataLenVarint
data = await conn.read(dataLenVarint.get().int) data = await conn.read(dataLenVarint.get().int)
trace "read size varint", varint = dataLenVarint
let header = headerVarint.get() let header = headerVarint.get()
result = some((header shr 3, MessageType(header and 0x7), data)) result = some((header shr 3, MessageType(header and 0x7), data))
@ -64,11 +66,13 @@ proc writeMsg*(conn: Connection,
data: seq[byte] = @[]) {.async, gcsafe.} = data: seq[byte] = @[]) {.async, gcsafe.} =
## write lenght prefixed ## write lenght prefixed
var buf = initVBuffer() var buf = initVBuffer()
let header = (id shl 3 or ord(msgType).uint) buf.writePBVarint(id shl 3 or ord(msgType).uint)
buf.writeVarint(id shl 3 or ord(msgType).uint) buf.writePBVarint(data.len().uint) # size should be always sent
buf.writeVarint(data.len().uint) # size should be always sent
buf.finish() buf.finish()
await conn.write(buf.buffer & data) try:
await conn.write(buf.buffer & data)
except LPStreamIncompleteError as exc:
trace "unable to send message", exc = exc.msg
proc writeMsg*(conn: Connection, proc writeMsg*(conn: Connection,
id: uint, id: uint,

View File

@ -7,7 +7,6 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import strformat
import chronos, chronicles import chronos, chronicles
import types, import types,
coder, coder,
@ -52,99 +51,110 @@ proc newChannel*(id: uint,
result.asyncLock = newAsyncLock() result.asyncLock = newAsyncLock()
let chan = result let chan = result
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = proc writeHandler(data: seq[byte]): Future[void] {.async.} =
# writes should happen in sequence # writes should happen in sequence
await chan.asyncLock.acquire() await chan.asyncLock.acquire()
trace "writeHandler: sending data ", data = data.toHex(), id = chan.id trace "sending data ", data = data.toHex(),
id = chan.id,
initiator = chan.initiator
await conn.writeMsg(chan.id, chan.msgCode, data) # write header await conn.writeMsg(chan.id, chan.msgCode, data) # write header
chan.asyncLock.release() chan.asyncLock.release()
result.initBufferStream(writeHandler, size) result.initBufferStream(writeHandler, size)
proc closeMessage(s: LPChannel) {.async, gcsafe.} = proc closeMessage(s: LPChannel) {.async.} =
await s.conn.writeMsg(s.id, s.closeCode) # write header await s.conn.writeMsg(s.id, s.closeCode) # write header
proc closed*(s: LPChannel): bool =
s.closedLocal and s.closedLocal
proc closedByRemote*(s: LPChannel) {.async.} = proc closedByRemote*(s: LPChannel) {.async.} =
s.closedRemote = true s.closedRemote = true
proc cleanUp*(s: LPChannel): Future[void] = proc cleanUp*(s: LPChannel): Future[void] =
# method which calls the underlying buffer's `close`
# method used instead of `close` since it's overloaded to
# simulate half-closed streams
result = procCall close(BufferStream(s)) result = procCall close(BufferStream(s))
proc open*(s: LPChannel): Future[void] =
s.conn.writeMsg(s.id, MessageType.New, s.name)
method close*(s: LPChannel) {.async, gcsafe.} = method close*(s: LPChannel) {.async, gcsafe.} =
s.closedLocal = true s.closedLocal = true
await s.closeMessage() await s.closeMessage()
proc resetMessage(s: LPChannel) {.async, gcsafe.} = proc resetMessage(s: LPChannel) {.async.} =
await s.conn.writeMsg(s.id, s.resetCode) await s.conn.writeMsg(s.id, s.resetCode)
proc resetByRemote*(s: LPChannel) {.async, gcsafe.} = proc resetByRemote*(s: LPChannel) {.async.} =
await allFutures(s.close(), s.closedByRemote()) await allFutures(s.close(), s.closedByRemote())
s.isReset = true s.isReset = true
proc reset*(s: LPChannel) {.async.} = proc reset*(s: LPChannel) {.async.} =
await allFutures(s.resetMessage(), s.resetByRemote()) await allFutures(s.resetMessage(), s.resetByRemote())
proc isReadEof(s: LPChannel): bool = method closed*(s: LPChannel): bool =
bool((s.closedRemote or s.closedLocal) and s.len() < 1) result = s.closedRemote and s.len == 0
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] {.gcsafe.} = proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] =
if s.closedRemote: if s.closedRemote or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
trace "pushing data to channel", data = data.toHex(),
id = s.id,
initiator = s.initiator
result = procCall pushTo(BufferStream(s), data) result = procCall pushTo(BufferStream(s), data)
method read*(s: LPChannel, n = -1): Future[seq[byte]] {.gcsafe.} = method read*(s: LPChannel, n = -1): Future[seq[byte]] =
if s.isReadEof(): if s.closed or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall read(BufferStream(s), n) result = procCall read(BufferStream(s), n)
method readExactly*(s: LPChannel, method readExactly*(s: LPChannel,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[void] {.gcsafe.} = Future[void] =
if s.isReadEof(): if s.closed or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall readExactly(BufferStream(s), pbytes, nbytes) result = procCall readExactly(BufferStream(s), pbytes, nbytes)
method readLine*(s: LPChannel, method readLine*(s: LPChannel,
limit = 0, limit = 0,
sep = "\r\n"): sep = "\r\n"):
Future[string] {.gcsafe.} = Future[string] =
if s.isReadEof(): if s.closed or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall readLine(BufferStream(s), limit, sep) result = procCall readLine(BufferStream(s), limit, sep)
method readOnce*(s: LPChannel, method readOnce*(s: LPChannel,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[int] {.gcsafe.} = Future[int] =
if s.isReadEof(): if s.closed or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall readOnce(BufferStream(s), pbytes, nbytes) result = procCall readOnce(BufferStream(s), pbytes, nbytes)
method readUntil*(s: LPChannel, method readUntil*(s: LPChannel,
pbytes: pointer, nbytes: int, pbytes: pointer, nbytes: int,
sep: seq[byte]): sep: seq[byte]):
Future[int] {.gcsafe.} = Future[int] =
if s.isReadEof(): if s.closed or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall readOnce(BufferStream(s), pbytes, nbytes) result = procCall readOnce(BufferStream(s), pbytes, nbytes)
method write*(s: LPChannel, method write*(s: LPChannel,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.gcsafe.} = nbytes: int): Future[void] =
if s.closedLocal: if s.closedLocal or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall write(BufferStream(s), pbytes, nbytes) result = procCall write(BufferStream(s), pbytes, nbytes)
method write*(s: LPChannel, msg: string, msglen = -1) {.async, gcsafe.} = method write*(s: LPChannel, msg: string, msglen = -1) {.async.} =
if s.closedLocal: if s.closedLocal or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall write(BufferStream(s), msg, msglen) result = procCall write(BufferStream(s), msg, msglen)
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async, gcsafe.} = method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} =
if s.closedLocal: if s.closedLocal or s.isReset:
raise newLPStreamClosedError() raise newLPStreamClosedError()
result = procCall write(BufferStream(s), msg, msglen) result = procCall write(BufferStream(s), msg, msglen)

View File

@ -11,16 +11,14 @@
## Timeouts and message limits are still missing ## Timeouts and message limits are still missing
## they need to be added ASAP ## they need to be added ASAP
import tables, sequtils, options, strformat import tables, sequtils, options
import chronos, chronicles import chronos, chronicles
import coder, types, lpchannel, import ../muxer,
../muxer,
../../varint,
../../connection, ../../connection,
../../vbuffer, ../../stream/lpstream,
../../protocols/protocol, coder,
../../stream/bufferstream, types,
../../stream/lpstream lpchannel
logScope: logScope:
topic = "Mplex" topic = "Mplex"
@ -34,9 +32,11 @@ type
proc getChannelList(m: Mplex, initiator: bool): var Table[uint, LPChannel] = proc getChannelList(m: Mplex, initiator: bool): var Table[uint, LPChannel] =
if initiator: if initiator:
result = m.remote trace "picking local channels", initiator = initiator
else:
result = m.local result = m.local
else:
trace "picking remote channels", initiator = initiator
result = m.remote
proc newStreamInternal*(m: Mplex, proc newStreamInternal*(m: Mplex,
initiator: bool = true, initiator: bool = true,
@ -45,17 +45,28 @@ proc newStreamInternal*(m: Mplex,
Future[LPChannel] {.async, gcsafe.} = Future[LPChannel] {.async, gcsafe.} =
## create new channel/stream ## create new channel/stream
let id = if initiator: m.currentId.inc(); m.currentId else: chanId let id = if initiator: m.currentId.inc(); m.currentId else: chanId
trace "creating new channel", channelId = id, initiator = initiator
result = newChannel(id, m.connection, initiator, name) result = newChannel(id, m.connection, initiator, name)
m.getChannelList(initiator)[id] = result m.getChannelList(initiator)[id] = result
proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async, inline.} =
## call the channel's `close` to signal the
## remote that the channel is closing
if not isNil(chann) and not chann.closed:
await chann.close()
await chann.cleanUp()
m.getChannelList(initiator).del(chann.id)
trace "cleaned up channel", id = chann.id
method handle*(m: Mplex) {.async, gcsafe.} = method handle*(m: Mplex) {.async, gcsafe.} =
trace "starting mplex main loop" trace "starting mplex main loop"
try: try:
while not m.connection.closed: while not m.connection.closed:
trace "waiting for data"
let msg = await m.connection.readMsg() let msg = await m.connection.readMsg()
if msg.isNone: if msg.isNone:
# TODO: allow poll with timeout to avoid using `sleepAsync` # TODO: allow poll with timeout to avoid using `sleepAsync`
await sleepAsync(10.millis) await sleepAsync(1.millis)
continue continue
let (id, msgType, data) = msg.get() let (id, msgType, data) = msg.get()
@ -63,8 +74,11 @@ method handle*(m: Mplex) {.async, gcsafe.} =
var channel: LPChannel var channel: LPChannel
if MessageType(msgType) != MessageType.New: if MessageType(msgType) != MessageType.New:
let channels = m.getChannelList(initiator) let channels = m.getChannelList(initiator)
if not channels.contains(id): if id notin channels:
trace "handle: Channel with id and msg type ", id = id, msg = msgType trace "Channel not found, skipping", id = id,
initiator = initiator,
msg = msgType
await sleepAsync(1.millis)
continue continue
channel = channels[id] channel = channels[id]
@ -72,36 +86,44 @@ method handle*(m: Mplex) {.async, gcsafe.} =
of MessageType.New: of MessageType.New:
let name = cast[string](data) let name = cast[string](data)
channel = await m.newStreamInternal(false, id, name) channel = await m.newStreamInternal(false, id, name)
trace "handle: created channel ", id = id, name = name trace "created channel", id = id, name = name, inititator = true
if not isNil(m.streamHandler): if not isNil(m.streamHandler):
let stream = newConnection(channel) let stream = newConnection(channel)
stream.peerInfo = m.connection.peerInfo stream.peerInfo = m.connection.peerInfo
let handlerFut = m.streamHandler(stream)
# channel cleanup routine # cleanup channel once handler is finished
proc cleanUpChan(udata: pointer) {.gcsafe.} = # stream.closeEvent.wait().addCallback(
if handlerFut.finished: # proc(udata: pointer) =
channel.close().addCallback( # asyncCheck cleanupChann(m, channel, initiator))
proc(udata: pointer) =
channel.cleanUp() asyncCheck m.streamHandler(stream)
.addCallback(proc(udata: pointer) =
trace "handle: cleaned up channel ", id = id))
handlerFut.addCallback(cleanUpChan)
continue continue
of MessageType.MsgIn, MessageType.MsgOut: of MessageType.MsgIn, MessageType.MsgOut:
trace "handle: pushing data to channel ", id = id, msgType = msgType trace "pushing data to channel", id = id,
initiator = initiator,
msgType = msgType
await channel.pushTo(data) await channel.pushTo(data)
of MessageType.CloseIn, MessageType.CloseOut: of MessageType.CloseIn, MessageType.CloseOut:
trace "handle: closing channel ", id = id, msgType = msgType trace "closing channel", id = id,
initiator = initiator,
msgType = msgType
await channel.closedByRemote() await channel.closedByRemote()
m.getChannelList(initiator).del(id) m.getChannelList(initiator).del(id)
of MessageType.ResetIn, MessageType.ResetOut: of MessageType.ResetIn, MessageType.ResetOut:
trace "handle: resetting channel ", id = id trace "resetting channel", id = id,
initiator = initiator,
msgType = msgType
await channel.resetByRemote() await channel.resetByRemote()
m.getChannelList(initiator).del(id)
break break
except: except CatchableError as exc:
error "exception occurred", exception = getCurrentExceptionMsg() trace "exception occurred", exception = exc.msg
finally: finally:
trace "stopping mplex main loop"
await m.connection.close() await m.connection.close()
proc newMplex*(conn: Connection, proc newMplex*(conn: Connection,
@ -112,13 +134,20 @@ proc newMplex*(conn: Connection,
result.remote = initTable[uint, LPChannel]() result.remote = initTable[uint, LPChannel]()
result.local = initTable[uint, LPChannel]() result.local = initTable[uint, LPChannel]()
let m = result
conn.closeEvent.wait().addCallback(
proc(udata: pointer) =
asyncCheck m.close()
)
method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} = method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} =
let channel = await m.newStreamInternal() let channel = await m.newStreamInternal()
await m.connection.writeMsg(channel.id, MessageType.New, name) # TODO: open the channel (this should be lazy)
await channel.open()
result = newConnection(channel) result = newConnection(channel)
result.peerInfo = m.connection.peerInfo result.peerInfo = m.connection.peerInfo
method close*(m: Mplex) {.async, gcsafe.} = method close*(m: Mplex) {.async, gcsafe.} =
await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.close())), trace "closing mplex muxer"
allFutures(toSeq(m.local.values).mapIt(it.close()))]) await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())),
m.connection.reset() allFutures(toSeq(m.local.values).mapIt(it.reset()))])

View File

@ -8,7 +8,6 @@
## those terms. ## those terms.
import chronos import chronos
import ../../connection
const MaxMsgSize* = 1 shl 20 # 1mb const MaxMsgSize* = 1 shl 20 # 1mb
const MaxChannels* = 1000 const MaxChannels* = 1000

View File

@ -10,7 +10,27 @@
import options import options
import peer, multiaddress import peer, multiaddress
type PeerInfo* = object of RootObj type
peerId*: Option[PeerID] PeerInfo* = object of RootObj
addrs*: seq[MultiAddress] peerId*: Option[PeerID]
protocols*: seq[string] addrs*: seq[MultiAddress]
protocols*: seq[string]
proc id*(p: PeerInfo): string =
if p.peerId.isSome:
result = p.peerId.get().pretty
proc `$`*(p: PeerInfo): string =
if p.peerId.isSome:
result.add("PeerID: ")
result.add(p.id & "\n")
if p.addrs.len > 0:
result.add("Peer Addrs: ")
for a in p.addrs:
result.add($a & "\n")
if p.protocols.len > 0:
result.add("Protocols: ")
for proto in p.protocols:
result.add(proto & "\n")

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import options, strformat import options
import chronos, chronicles import chronos, chronicles
import ../protobuf/minprotobuf, import ../protobuf/minprotobuf,
../peerinfo, ../peerinfo,
@ -115,14 +115,14 @@ method init*(p: Identify) =
trace "handling identify request" trace "handling identify request"
var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs()) var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
# await conn.close() #TODO: investigate why this breaks
p.handler = handle p.handler = handle
p.codec = IdentifyCodec p.codec = IdentifyCodec
proc identify*(p: Identify, proc identify*(p: Identify,
conn: Connection, conn: Connection,
remotePeerInfo: PeerInfo): remotePeerInfo: PeerInfo): Future[IdentifyInfo] {.async, gcsafe.} =
Future[IdentifyInfo] {.async.} =
var message = await conn.readLp() var message = await conn.readLp()
if len(message) == 0: if len(message) == 0:
trace "identify: Invalid or empty message received!" trace "identify: Invalid or empty message received!"
@ -139,7 +139,7 @@ proc identify*(p: Identify,
if peer != remotePeerInfo.peerId.get(): if peer != remotePeerInfo.peerId.get():
trace "Peer ids don't match", trace "Peer ids don't match",
remote = peer.pretty(), remote = peer.pretty(),
local = remotePeerInfo.peerId.get().pretty() local = remotePeerInfo.id
raise newException(IdentityNoMatchError, raise newException(IdentityNoMatchError,
"Peer ids don't match") "Peer ids don't match")
@ -149,5 +149,4 @@ proc identify*(p: Identify,
proc push*(p: Identify, conn: Connection) {.async.} = proc push*(p: Identify, conn: Connection) {.async.} =
await conn.write(IdentifyPushCodec) await conn.write(IdentifyPushCodec)
var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs()) var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
let length = pb.getLen()
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)

View File

@ -8,9 +8,7 @@
## those terms. ## those terms.
import chronos import chronos
import ../connection, import ../connection
../peerinfo,
../multiaddress
type type
LPProtoHandler* = proc (conn: Connection, LPProtoHandler* = proc (conn: Connection,

View File

@ -14,6 +14,7 @@ import rpcmsg,
../../peer, ../../peer,
../../peerinfo, ../../peerinfo,
../../connection, ../../connection,
../../stream/lpstream,
../../crypto/crypto, ../../crypto/crypto,
../../protobuf/minprotobuf ../../protobuf/minprotobuf
@ -45,7 +46,7 @@ proc handle*(p: PubSubPeer) {.async, gcsafe.} =
trace "Decoded msg from peer", peer = p.id, msg = msg trace "Decoded msg from peer", peer = p.id, msg = msg
await p.handler(p, @[msg]) await p.handler(p, @[msg])
except: except:
error "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg() trace "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
finally: finally:
trace "closing connection to pubsub peer", peer = p.id trace "closing connection to pubsub peer", peer = p.id
await p.conn.close() await p.conn.close()

View File

@ -8,8 +8,7 @@
## those terms. ## those terms.
import chronos import chronos
import secure, import secure, ../../connection
../../connection
const PlainTextCodec* = "/plaintext/1.0.0" const PlainTextCodec* = "/plaintext/1.0.0"

View File

@ -6,10 +6,12 @@
## at your option. ## at your option.
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import options
import chronos, chronicles import chronos, chronicles
import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode] import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode]
import secure, import secure,
../../connection, ../../connection,
../../stream/lpstream,
../../crypto/crypto, ../../crypto/crypto,
../../crypto/ecnist, ../../crypto/ecnist,
../../protobuf/minprotobuf, ../../protobuf/minprotobuf,
@ -60,7 +62,6 @@ type
ctxsha1: HMAC[sha1] ctxsha1: HMAC[sha1]
SecureConnection* = ref object of Connection SecureConnection* = ref object of Connection
conn*: Connection
writerMac: SecureMac writerMac: SecureMac
readerMac: SecureMac readerMac: SecureMac
writerCoder: SecureCipher writerCoder: SecureCipher
@ -176,13 +177,13 @@ proc readMessage*(sconn: SecureConnection): Future[seq[byte]] {.async.} =
## Read message from channel secure connection ``sconn``. ## Read message from channel secure connection ``sconn``.
try: try:
var buf = newSeq[byte](4) var buf = newSeq[byte](4)
await sconn.conn.readExactly(addr buf[0], 4) await sconn.readExactly(addr buf[0], 4)
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
(int(buf[2]) shl 8) or (int(buf[3])) (int(buf[2]) shl 8) or (int(buf[3]))
trace "Recieved message header", header = toHex(buf), length = length trace "Recieved message header", header = toHex(buf), length = length
if length <= SecioMaxMessageSize: if length <= SecioMaxMessageSize:
buf.setLen(length) buf.setLen(length)
await sconn.conn.readExactly(addr buf[0], length) await sconn.readExactly(addr buf[0], length)
trace "Received message body", length = length, trace "Received message body", length = length,
buffer = toHex(buf) buffer = toHex(buf)
if sconn.macCheckAndDecode(buf): if sconn.macCheckAndDecode(buf):
@ -213,21 +214,27 @@ proc writeMessage*(sconn: SecureConnection, message: seq[byte]) {.async.} =
msg[3] = byte(length and 0xFF) msg[3] = byte(length and 0xFF)
trace "Writing message", message = toHex(msg) trace "Writing message", message = toHex(msg)
try: try:
await sconn.conn.write(msg) await sconn.write(msg)
except AsyncStreamWriteError: except AsyncStreamWriteError:
trace "Could not write to connection" trace "Could not write to connection"
proc newSecureConnection*(conn: Connection, hash: string, cipher: string, proc newSecureConnection*(conn: Connection,
hash: string,
cipher: string,
secrets: Secret, secrets: Secret,
order: int): SecureConnection = order: int,
peerId: PeerID): SecureConnection =
## Create new secure connection, using specified hash algorithm ``hash``, ## Create new secure connection, using specified hash algorithm ``hash``,
## cipher algorithm ``cipher``, stretched keys ``secrets`` and order ## cipher algorithm ``cipher``, stretched keys ``secrets`` and order
## ``order``. ## ``order``.
new result new result
result.stream = conn
result.closeEvent = newAsyncEvent()
let i0 = if order < 0: 1 else: 0 let i0 = if order < 0: 1 else: 0
let i1 = if order < 0: 0 else: 1 let i1 = if order < 0: 0 else: 1
result.conn = conn
trace "Writer credentials", mackey = toHex(secrets.macOpenArray(i0)), trace "Writer credentials", mackey = toHex(secrets.macOpenArray(i0)),
enckey = toHex(secrets.keyOpenArray(i0)), enckey = toHex(secrets.keyOpenArray(i0)),
iv = toHex(secrets.ivOpenArray(i0)) iv = toHex(secrets.ivOpenArray(i0))
@ -241,6 +248,8 @@ proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
result.readerCoder.init(cipher, secrets.keyOpenArray(i1), result.readerCoder.init(cipher, secrets.keyOpenArray(i1),
secrets.ivOpenArray(i1)) secrets.ivOpenArray(i1))
result.peerInfo.peerId = some(peerId)
proc transactMessage(conn: Connection, proc transactMessage(conn: Connection,
msg: seq[byte]): Future[seq[byte]] {.async.} = msg: seq[byte]): Future[seq[byte]] {.async.} =
var buf = newSeq[byte](4) var buf = newSeq[byte](4)
@ -281,7 +290,6 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
remoteHashes: string remoteHashes: string
remotePeerId: PeerID remotePeerId: PeerID
localPeerId: PeerID localPeerId: PeerID
ekey: PrivateKey
localBytesPubkey = s.localPublicKey.getBytes() localBytesPubkey = s.localPublicKey.getBytes()
if randomBytes(localNonce) != SecioNonceSize: if randomBytes(localNonce) != SecioNonceSize:
@ -388,7 +396,8 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
# Perform Nonce exchange over encrypted channel. # Perform Nonce exchange over encrypted channel.
result = newSecureConnection(conn, hash, cipher, keys, order) result = newSecureConnection(conn, hash, cipher, keys, order, remotePeerId)
await result.writeMessage(remoteNonce) await result.writeMessage(remoteNonce)
var res = await result.readMessage() var res = await result.readMessage()
@ -400,17 +409,21 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
trace "Secure handshake succeeded" trace "Secure handshake succeeded"
proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} = proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} =
while not sconn.conn.closed: try:
try: while not sconn.closed:
let msg = await sconn.readMessage() let msg = await sconn.readMessage()
await stream.pushTo(msg) if msg.len > 0:
except CatchableError as exc: await stream.pushTo(msg)
trace "exception in secio", exc = exc.msg
return
finally:
trace "ending secio readLoop"
proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} = # tight loop, give a chance for other
# stuff to run as well
await sleepAsync(1.millis)
except CatchableError as exc:
trace "exception occured", exc = exc.msg
finally:
trace "ending secio readLoop", isclosed = sconn.closed()
proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn) var sconn = await s.handshake(conn)
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
trace "sending encrypted bytes", bytes = data.toHex() trace "sending encrypted bytes", bytes = data.toHex()
@ -419,7 +432,13 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
var stream = newBufferStream(writeHandler) var stream = newBufferStream(writeHandler)
asyncCheck readLoop(sconn, stream) asyncCheck readLoop(sconn, stream)
var secured = newConnection(stream) var secured = newConnection(stream)
secured.peerInfo = sconn.conn.peerInfo secured.closeEvent.wait()
.addCallback(proc(udata: pointer) =
trace "wrapped connection closed, closing upstream"
if not sconn.closed:
asyncCheck sconn.close()
)
secured.peerInfo.peerId = sconn.peerInfo.peerId
result = secured result = secured
method init(s: Secio) {.gcsafe.} = method init(s: Secio) {.gcsafe.} =

View File

@ -25,12 +25,12 @@
## ordered and asynchronous. Reads are queued up in order ## ordered and asynchronous. Reads are queued up in order
## and are suspended when not enough data available. This ## and are suspended when not enough data available. This
## allows preserving backpressure while maintaining full ## allows preserving backpressure while maintaining full
## asynchrony. Both writting to the internal buffer with ## asynchrony. Both writting to the internal buffer with
## ``pushTo`` as well as reading with ``read*` methods, ## ``pushTo`` as well as reading with ``read*` methods,
## will suspend until either the amount of elements in the ## will suspend until either the amount of elements in the
## buffer goes below ``maxSize`` or more data becomes available. ## buffer goes below ``maxSize`` or more data becomes available.
import deques, tables, sequtils, math import deques, math
import chronos import chronos
import ../stream/lpstream import ../stream/lpstream
@ -38,14 +38,25 @@ const DefaultBufferSize* = 1024
type type
# TODO: figure out how to make this generic to avoid casts # TODO: figure out how to make this generic to avoid casts
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} WriteHandler* = proc (data: seq[byte]): Future[void]
BufferStream* = ref object of LPStream BufferStream* = ref object of LPStream
maxSize*: int # buffer's max size in bytes maxSize*: int # buffer's max size in bytes
readBuf: Deque[byte] # a deque is based on a ring buffer readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here
readReqs: Deque[Future[void]] # use dequeue to fire reads in order readReqs: Deque[Future[void]] # use dequeue to fire reads in order
dataReadEvent: AsyncEvent dataReadEvent: AsyncEvent
writeHandler*: WriteHandler writeHandler*: WriteHandler
lock: AsyncLock
isPiped: bool
AlreadyPipedError* = object of CatchableError
NotWritableError* = object of CatchableError
proc newAlreadyPipedError*(): ref Exception {.inline.} =
result = newException(AlreadyPipedError, "stream already piped")
proc newNotWritableError*(): ref Exception {.inline.} =
result = newException(NotWritableError, "stream is not writable")
proc requestReadBytes(s: BufferStream): Future[void] = proc requestReadBytes(s: BufferStream): Future[void] =
## create a future that will complete when more ## create a future that will complete when more
@ -53,14 +64,19 @@ proc requestReadBytes(s: BufferStream): Future[void] =
result = newFuture[void]() result = newFuture[void]()
s.readReqs.addLast(result) s.readReqs.addLast(result)
proc initBufferStream*(s: BufferStream, handler: WriteHandler, size: int = DefaultBufferSize) = proc initBufferStream*(s: BufferStream,
handler: WriteHandler = nil,
size: int = DefaultBufferSize) =
s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size) s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size)
s.readBuf = initDeque[byte](s.maxSize) s.readBuf = initDeque[byte](s.maxSize)
s.readReqs = initDeque[Future[void]]() s.readReqs = initDeque[Future[void]]()
s.dataReadEvent = newAsyncEvent() s.dataReadEvent = newAsyncEvent()
s.lock = newAsyncLock()
s.writeHandler = handler s.writeHandler = handler
s.closeEvent = newAsyncEvent()
proc newBufferStream*(handler: WriteHandler, size: int = DefaultBufferSize): BufferStream = proc newBufferStream*(handler: WriteHandler = nil,
size: int = DefaultBufferSize): BufferStream =
new result new result
result.initBufferStream(handler, size) result.initBufferStream(handler, size)
@ -78,15 +94,24 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) =
proc len*(s: BufferStream): int = s.readBuf.len proc len*(s: BufferStream): int = s.readBuf.len
proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} = proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
## Write bytes to internal read buffer, use this to fill up the ## Write bytes to internal read buffer, use this to fill up the
## buffer with data. ## buffer with data.
## ##
## This method is async and will wait until all data has been ## This method is async and will wait until all data has been
## written to the internal buffer; this is done so that backpressure ## written to the internal buffer; this is done so that backpressure
## is preserved. ## is preserved.
##
await s.lock.acquire()
var index = 0 var index = 0
while true: while true:
# give readers a chance free up the buffer
# it it's full.
if s.readBuf.len >= s.maxSize:
await sleepAsync(10.millis)
while index < data.len and s.readBuf.len < s.maxSize: while index < data.len and s.readBuf.len < s.maxSize:
s.readBuf.addLast(data[index]) s.readBuf.addLast(data[index])
inc(index) inc(index)
@ -101,11 +126,13 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
# if we couldn't transfer all the data to the # if we couldn't transfer all the data to the
# internal buf wait on a read event # internal buf wait on a read event
await s.dataReadEvent.wait() await s.dataReadEvent.wait()
s.lock.release()
method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} = method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async.} =
## Read all bytes (n <= 0) or exactly `n` bytes from buffer ## Read all bytes (n <= 0) or exactly `n` bytes from buffer
## ##
## This procedure allocates buffer seq[byte] and return it as result. ## This procedure allocates buffer seq[byte] and return it as result.
##
var size = if n > 0: n else: s.readBuf.len() var size = if n > 0: n else: s.readBuf.len()
var index = 0 var index = 0
while index < size: while index < size:
@ -119,22 +146,23 @@ method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
method readExactly*(s: BufferStream, method readExactly*(s: BufferStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[void] {.async, gcsafe.} = Future[void] {.async.} =
## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store ## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store
## it to ``pbytes``. ## it to ``pbytes``.
## ##
## If EOF is received and ``nbytes`` is not yet read, the procedure ## If EOF is received and ``nbytes`` is not yet read, the procedure
## will raise ``LPStreamIncompleteError``. ## will raise ``LPStreamIncompleteError``.
let buff = await s.read(nbytes) ##
var buff = await s.read(nbytes)
if nbytes > buff.len(): if nbytes > buff.len():
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
copyMem(pbytes, unsafeAddr buff[0], nbytes) copyMem(pbytes, addr buff[0], nbytes)
method readLine*(s: BufferStream, method readLine*(s: BufferStream,
limit = 0, limit = 0,
sep = "\r\n"): sep = "\r\n"):
Future[string] {.async, gcsafe.} = Future[string] {.async.} =
## Read one line from read-only stream ``rstream``, where ``"line"`` is a ## Read one line from read-only stream ``rstream``, where ``"line"`` is a
## sequence of bytes ending with ``sep`` (default is ``"\r\n"``). ## sequence of bytes ending with ``sep`` (default is ``"\r\n"``).
## ##
@ -146,6 +174,7 @@ method readLine*(s: BufferStream,
## ##
## If ``limit`` more then 0, then result string will be limited to ``limit`` ## If ``limit`` more then 0, then result string will be limited to ``limit``
## bytes. ## bytes.
##
result = "" result = ""
var lim = if limit <= 0: -1 else: limit var lim = if limit <= 0: -1 else: limit
var state = 0 var state = 0
@ -170,11 +199,12 @@ method readLine*(s: BufferStream,
method readOnce*(s: BufferStream, method readOnce*(s: BufferStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[int] {.async, gcsafe.} = Future[int] {.async.} =
## Perform one read operation on read-only stream ``rstream``. ## Perform one read operation on read-only stream ``rstream``.
## ##
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
## internal buffer, otherwise it will wait until some bytes will be received. ## internal buffer, otherwise it will wait until some bytes will be received.
##
if s.readBuf.len == 0: if s.readBuf.len == 0:
await s.requestReadBytes() await s.requestReadBytes()
@ -186,7 +216,7 @@ method readUntil*(s: BufferStream,
pbytes: pointer, pbytes: pointer,
nbytes: int, nbytes: int,
sep: seq[byte]): sep: seq[byte]):
Future[int] {.async, gcsafe.} = Future[int] {.async.} =
## Read data from the read-only stream ``rstream`` until separator ``sep`` is ## Read data from the read-only stream ``rstream`` until separator ``sep`` is
## found. ## found.
## ##
@ -200,6 +230,7 @@ method readUntil*(s: BufferStream,
## will raise ``LPStreamLimitError``. ## will raise ``LPStreamLimitError``.
## ##
## Procedure returns actual number of bytes read. ## Procedure returns actual number of bytes read.
##
var var
dest = cast[ptr UncheckedArray[byte]](pbytes) dest = cast[ptr UncheckedArray[byte]](pbytes)
state = 0 state = 0
@ -233,20 +264,20 @@ method readUntil*(s: BufferStream,
method write*(s: BufferStream, method write*(s: BufferStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] nbytes: int): Future[void] =
{.gcsafe.} =
## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream
## ``rstream``. ## ``rstream``.
## ##
## Return number of bytes actually consumed (discarded). ## Return number of bytes actually consumed (discarded).
##
var buf: seq[byte] = newSeq[byte](nbytes) var buf: seq[byte] = newSeq[byte](nbytes)
copyMem(addr buf[0], pbytes, nbytes) copyMem(addr buf[0], pbytes, nbytes)
result = s.writeHandler(buf) if not isNil(s.writeHandler):
result = s.writeHandler(buf)
method write*(s: BufferStream, method write*(s: BufferStream,
msg: string, msg: string,
msglen = -1): Future[void] msglen = -1): Future[void] =
{.gcsafe.} =
## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``. ## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``.
## ##
## String ``sbytes`` must not be zero-length. ## String ``sbytes`` must not be zero-length.
@ -254,14 +285,15 @@ method write*(s: BufferStream,
## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream. ## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream.
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
## stream. ## stream.
##
var buf = "" var buf = ""
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg) shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
result = s.writeHandler(cast[seq[byte]](buf)) if not isNil(s.writeHandler):
result = s.writeHandler(cast[seq[byte]](buf))
method write*(s: BufferStream, method write*(s: BufferStream,
msg: seq[byte], msg: seq[byte],
msglen = -1): Future[void] msglen = -1): Future[void] =
{.gcsafe.} =
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
## stream ``wstream``. ## stream ``wstream``.
## ##
@ -270,13 +302,56 @@ method write*(s: BufferStream,
## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream. ## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream.
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
## stream. ## stream.
##
var buf: seq[byte] var buf: seq[byte]
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg) shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
result = s.writeHandler(buf) if not isNil(s.writeHandler):
result = s.writeHandler(buf)
method close*(s: BufferStream) {.async, gcsafe.} = proc pipe*(s: BufferStream,
target: BufferStream): BufferStream =
## pipe the write end of this stream to
## be the source of the target stream
##
## Note that this only works with the LPStream
## interface methods `read*` and `write` are
## piped.
##
if s.isPiped:
raise newAlreadyPipedError()
s.isPiped = true
let oldHandler = target.writeHandler
proc handler(data: seq[byte]) {.async, closure.} =
if not isNil(oldHandler):
await oldHandler(data)
# if we're piping to self,
# then add the data to the
# buffer directly and fire
# the read event
if s == target:
for b in data:
s.readBuf.addLast(b)
# notify main loop of available
# data
s.dataReadEvent.fire()
else:
await target.pushTo(data)
s.writeHandler = handler
result = target
proc `|`*(s: BufferStream, target: BufferStream): BufferStream =
## pipe operator to make piping less verbose
pipe(s, target)
method close*(s: BufferStream) {.async.} =
## close the stream and clear the buffer ## close the stream and clear the buffer
for r in s.readReqs: for r in s.readReqs:
r.cancel() r.cancel()
s.dataReadEvent.fire()
s.readBuf.clear() s.readBuf.clear()
s.closed = true s.closeEvent.fire()
s.isClosed = true

View File

@ -26,40 +26,63 @@ proc newChronosStream*(server: StreamServer,
result.client = client result.client = client
result.reader = newAsyncStreamReader(client) result.reader = newAsyncStreamReader(client)
result.writer = newAsyncStreamWriter(client) result.writer = newAsyncStreamWriter(client)
result.closed = false result.closeEvent = newAsyncEvent()
method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
try: try:
result = await s.reader.read(n) result = await s.reader.read(n)
except AsyncStreamReadError as exc: except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par) raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readExactly*(s: ChronosStream, method readExactly*(s: ChronosStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.async, gcsafe.} = nbytes: int): Future[void] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
try: try:
await s.reader.readExactly(pbytes, nbytes) await s.reader.readExactly(pbytes, nbytes)
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
except AsyncStreamReadError as exc: except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par) raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async, gcsafe.} =
try: try:
result = await s.reader.readLine(limit, sep) result = await s.reader.readLine(limit, sep)
except AsyncStreamReadError as exc: except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par) raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async, gcsafe.} =
try: try:
result = await s.reader.readOnce(pbytes, nbytes) result = await s.reader.readOnce(pbytes, nbytes)
except AsyncStreamReadError as exc: except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par) raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readUntil*(s: ChronosStream, method readUntil*(s: ChronosStream,
pbytes: pointer, pbytes: pointer,
nbytes: int, nbytes: int,
sep: seq[byte]): Future[int] {.async, gcsafe.} = sep: seq[byte]): Future[int] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
try: try:
result = await s.reader.readUntil(pbytes, nbytes, sep) result = await s.reader.readUntil(pbytes, nbytes, sep)
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
@ -68,36 +91,62 @@ method readUntil*(s: ChronosStream,
raise newLPStreamLimitError() raise newLPStreamLimitError()
except LPStreamReadError as exc: except LPStreamReadError as exc:
raise newLPStreamReadError(exc.par) raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async.} =
if s.writer.atEof:
raise newLPStreamClosedError()
method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async, gcsafe.} =
try: try:
await s.writer.write(pbytes, nbytes) await s.writer.write(pbytes, nbytes)
except AsyncStreamWriteError as exc: except AsyncStreamWriteError as exc:
raise newLPStreamWriteError(exc.par) raise newLPStreamWriteError(exc.par)
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method write*(s: ChronosStream, msg: string, msglen = -1) {.async.} =
if s.writer.atEof:
raise newLPStreamClosedError()
method write*(s: ChronosStream, msg: string, msglen = -1) {.async, gcsafe.} =
try: try:
await s.writer.write(msg, msglen) await s.writer.write(msg, msglen)
except AsyncStreamWriteError as exc: except AsyncStreamWriteError as exc:
raise newLPStreamWriteError(exc.par) raise newLPStreamWriteError(exc.par)
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} =
if s.writer.atEof:
raise newLPStreamClosedError()
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
try: try:
await s.writer.write(msg, msglen) await s.writer.write(msg, msglen)
except AsyncStreamWriteError as exc: except AsyncStreamWriteError as exc:
raise newLPStreamWriteError(exc.par) raise newLPStreamWriteError(exc.par)
except AsyncStreamIncompleteError: except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method close*(s: ChronosStream) {.async, gcsafe.} = method closed*(s: ChronosStream): bool {.inline.} =
# TODO: we might only need to check for reader's EOF
result = s.reader.atEof()
method close*(s: ChronosStream) {.async.} =
if not s.closed: if not s.closed:
trace "shutting down server", address = $s.client.remoteAddress() trace "shutting chronos stream", address = $s.client.remoteAddress()
await s.writer.finish() if not s.writer.closed():
await s.writer.closeWait() await s.writer.closeWait()
await s.reader.closeWait()
await s.client.closeWait() if not s.reader.closed():
s.closed = true await s.reader.closeWait()
if not s.client.closed():
await s.client.closeWait()
s.closeEvent.fire()

View File

@ -11,7 +11,8 @@ import chronos
type type
LPStream* = ref object of RootObj LPStream* = ref object of RootObj
closed*: bool isClosed*: bool
closeEvent*: AsyncEvent
LPStreamError* = object of CatchableError LPStreamError* = object of CatchableError
LPStreamIncompleteError* = object of LPStreamError LPStreamIncompleteError* = object of LPStreamError
@ -47,40 +48,43 @@ proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
proc newLPStreamClosedError*(): ref Exception {.inline.} = proc newLPStreamClosedError*(): ref Exception {.inline.} =
result = newException(LPStreamClosedError, "Stream closed!") result = newException(LPStreamClosedError, "Stream closed!")
method closed*(s: LPStream): bool {.base, inline.} =
s.isClosed
method read*(s: LPStream, n = -1): Future[seq[byte]] method read*(s: LPStream, n = -1): Future[seq[byte]]
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method readExactly*(s: LPStream, pbytes: pointer, nbytes: int): Future[void] method readExactly*(s: LPStream, pbytes: pointer, nbytes: int): Future[void]
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] method readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string]
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method readOnce*(s: LPStream, pbytes: pointer, nbytes: int): Future[int] method readOnce*(s: LPStream, pbytes: pointer, nbytes: int): Future[int]
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method readUntil*(s: LPStream, method readUntil*(s: LPStream,
pbytes: pointer, nbytes: int, pbytes: pointer, nbytes: int,
sep: seq[byte]): Future[int] sep: seq[byte]): Future[int]
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method write*(s: LPStream, pbytes: pointer, nbytes: int) method write*(s: LPStream, pbytes: pointer, nbytes: int)
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method write*(s: LPStream, msg: string, msglen = -1) method write*(s: LPStream, msg: string, msglen = -1)
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method write*(s: LPStream, msg: seq[byte], msglen = -1) method write*(s: LPStream, msg: seq[byte], msglen = -1)
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
method close*(s: LPStream) method close*(s: LPStream)
{.base, async, gcsafe.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")

View File

@ -11,13 +11,11 @@ import tables, sequtils, options, strformat
import chronos, chronicles import chronos, chronicles
import connection, import connection,
transports/transport, transports/transport,
stream/lpstream,
multistream, multistream,
protocols/protocol, protocols/protocol,
protocols/secure/secure, protocols/secure/secure,
protocols/secure/plaintext, # for plain text protocols/secure/plaintext, # for plain text
peerinfo, peerinfo,
multiaddress,
protocols/identify, protocols/identify,
protocols/pubsub/pubsub, protocols/pubsub/pubsub,
muxers/muxer, muxers/muxer,
@ -26,6 +24,12 @@ import connection,
logScope: logScope:
topic = "Switch" topic = "Switch"
#TODO: General note - use a finite state machine to manage the different
# steps of connections establishing and upgrading. This makes everything
# more robust and less prone to ordering attacks - i.e. muxing can come if
# and only if the channel has been secured (i.e. if a secure manager has been
# previously provided)
type type
NoPubSubException = object of CatchableError NoPubSubException = object of CatchableError
@ -48,7 +52,6 @@ proc newNoPubSubException(): ref Exception {.inline.} =
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
## secure the incoming connection ## secure the incoming connection
# plaintext for now, doesn't do anything
let managers = toSeq(s.secureManagers.keys) let managers = toSeq(s.secureManagers.keys)
if managers.len == 0: if managers.len == 0:
raise newException(CatchableError, "No secure managers registered!") raise newException(CatchableError, "No secure managers registered!")
@ -62,12 +65,14 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} = proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
## identify the connection ## identify the connection
result = conn.peerInfo
try: try:
if (await s.ms.select(conn, s.identity.codec)): if (await s.ms.select(conn, s.identity.codec)):
let info = await s.identity.identify(conn, conn.peerInfo) let info = await s.identity.identify(conn, conn.peerInfo)
if info.pubKey.isSome: if info.pubKey.isSome:
result.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all result.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all
trace "identify: identified remote peer", peer = result.id
if info.addrs.len > 0: if info.addrs.len > 0:
result.addrs = info.addrs result.addrs = info.addrs
@ -75,7 +80,6 @@ proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
if info.protos.len > 0: if info.protos.len > 0:
result.protocols = info.protos result.protocols = info.protos
trace "identify: identified remote peer ", peer = result.peerId.get().pretty
except IdentityInvalidMsgError as exc: except IdentityInvalidMsgError as exc:
error "identify: invalid message", msg = exc.msg error "identify: invalid message", msg = exc.msg
except IdentityNoMatchError as exc: except IdentityNoMatchError as exc:
@ -100,22 +104,23 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
muxer.streamHandler = s.streamHandler muxer.streamHandler = s.streamHandler
# new stream for identify # new stream for identify
let stream = await muxer.newStream() var stream = await muxer.newStream()
let handlerFut = muxer.handle() let handlerFut = muxer.handle()
# add muxer handler cleanup proc # add muxer handler cleanup proc
handlerFut.addCallback( handlerFut.addCallback(
proc(udata: pointer = nil) {.gcsafe.} = proc(udata: pointer = nil) {.gcsafe.} =
trace "mux: Muxer handler completed for peer ", trace "muxer handler completed for peer",
peer = conn.peerInfo.peerId.get().pretty peer = conn.peerInfo.id
) )
# do identify first, so that we have a # do identify first, so that we have a
# PeerInfo in case we didn't before # PeerInfo in case we didn't before
conn.peerInfo = await s.identify(stream) conn.peerInfo = await s.identify(stream)
await stream.close() # close idenity stream
trace "connection's peerInfo", peerInfo = conn.peerInfo.peerId await stream.close() # close identify stream
trace "connection's peerInfo", peerInfo = conn.peerInfo
# store it in muxed connections if we have a peer for it # store it in muxed connections if we have a peer for it
# TODO: We should make sure that this are cleaned up properly # TODO: We should make sure that this are cleaned up properly
@ -123,43 +128,42 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
# happen once secio is in place, but still something to keep # happen once secio is in place, but still something to keep
# in mind # in mind
if conn.peerInfo.peerId.isSome: if conn.peerInfo.peerId.isSome:
trace "adding muxer for peer", peer = conn.peerInfo.peerId.get().pretty trace "adding muxer for peer", peer = conn.peerInfo.id
s.muxed[conn.peerInfo.peerId.get().pretty] = muxer s.muxed[conn.peerInfo.id] = muxer
proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
if conn.peerInfo.peerId.isSome: if conn.peerInfo.peerId.isSome:
let id = conn.peerInfo.peerId.get().pretty let id = conn.peerInfo.id
if s.muxed.contains(id): trace "cleaning up connection for peer", peerId = id
await s.muxed[id].close if id in s.muxed:
await s.muxed[id].close()
s.muxed.del(id)
if s.connections.contains(id): if id in s.connections:
await s.connections[id].close() await s.connections[id].close()
s.connections.del(id)
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} = proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} =
# if there is a muxer for the connection # if there is a muxer for the connection
# use it instead to create a muxed stream # use it instead to create a muxed stream
if s.muxed.contains(peerInfo.peerId.get().pretty): if peerInfo.id in s.muxed:
trace "connection is muxed, retriving muxer and setting up a stream" trace "connection is muxed, setting up a stream"
let muxer = s.muxed[peerInfo.peerId.get().pretty] let muxer = s.muxed[peerInfo.id]
let conn = await muxer.newStream() let conn = await muxer.newStream()
result = some(conn) result = some(conn)
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
trace "handling connection", conn = conn trace "handling connection", conn = conn
result = conn result = conn
## perform upgrade flow
if result.peerInfo.peerId.isSome:
let id = result.peerInfo.peerId.get().pretty
if s.connections.contains(id):
# if we already have a connection for this peer,
# close the incoming connection and return the
# existing one
await result.close()
return s.connections[id]
s.connections[id] = result
result = await s.secure(conn) # secure the connection # don't mux/secure twise
if conn.peerInfo.peerId.isSome and
conn.peerInfo.id in s.muxed:
return
result = await s.secure(result) # secure the connection
await s.mux(result) # mux it if possible await s.mux(result) # mux it if possible
s.connections[conn.peerInfo.id] = result
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
trace "upgrading incoming connection" trace "upgrading incoming connection"
@ -192,25 +196,35 @@ proc dial*(s: Switch,
peer: PeerInfo, peer: PeerInfo,
proto: string = ""): proto: string = ""):
Future[Connection] {.async.} = Future[Connection] {.async.} =
trace "dialing peer", peer = peer.peerId.get().pretty let id = peer.id
trace "dialing peer", peer = id
for t in s.transports: # for each transport for t in s.transports: # for each transport
for a in peer.addrs: # for each address for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it if t.handles(a): # check if it can dial it
result = await t.dial(a) if id notin s.connections:
# make sure to assign the peer to the connection trace "dialing address", address = $a
result.peerInfo = peer result = await t.dial(a)
# make sure to assign the peer to the connection
result.peerInfo = peer
result = await s.upgradeOutgoing(result) result = await s.upgradeOutgoing(result)
result.closeEvent.wait().addCallback(
proc(udata: pointer) =
asyncCheck s.cleanupConn(result)
)
let stream = await s.getMuxedStream(peer) if proto.len > 0 and not result.closed:
if stream.isSome: let stream = await s.getMuxedStream(peer)
trace "connection is muxed, return muxed stream" if stream.isSome:
result = stream.get() trace "connection is muxed, return muxed stream"
result = stream.get()
trace "attempting to select remote", proto = proto
trace "dial: attempting to select remote ", proto = proto if not (await s.ms.select(result, proto)):
if not (await s.ms.select(result, proto)): error "unable to select protocol: ", proto = proto
error "dial: Unable to select protocol: ", proto = proto raise newException(CatchableError,
raise newException(CatchableError, &"unable to select protocol: {proto}")
&"Unable to select protocol: {proto}")
break # don't dial more than one addr on the same transport
proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
if isNil(proto.handler): if isNil(proto.handler):
@ -224,10 +238,15 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
s.ms.addHandler(proto.codec, proto) s.ms.addHandler(proto.codec, proto)
proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
trace "starting switch"
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
try: try:
await s.upgradeIncoming(conn) # perform upgrade on incoming connection await s.upgradeIncoming(conn) # perform upgrade on incoming connection
except CatchableError as exc:
trace "exception occured", exc = exc.msg
finally: finally:
await conn.close()
await s.cleanupConn(conn) await s.cleanupConn(conn)
var startFuts: seq[Future[void]] var startFuts: seq[Future[void]]
@ -237,10 +256,13 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
var server = await t.listen(a, handle) var server = await t.listen(a, handle)
s.peerInfo.addrs[i] = t.ma # update peer's address s.peerInfo.addrs[i] = t.ma # update peer's address
startFuts.add(server) startFuts.add(server)
result = startFuts # listen for incoming connections result = startFuts # listen for incoming connections
proc stop*(s: Switch) {.async.} = proc stop*(s: Switch) {.async.} =
await allFutures(toSeq(s.connections.values).mapIt(it.close())) trace "stopping switch"
await allFutures(toSeq(s.connections.values).mapIt(s.cleanupConn(it)))
await allFutures(s.transports.mapIt(it.close())) await allFutures(s.transports.mapIt(it.close()))
proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import chronos, chronicles import chronos, chronicles, sequtils
import transport, import transport,
../wire, ../wire,
../connection, ../connection,
@ -78,5 +78,5 @@ method dial*(t: TcpTransport,
result = await t.connHandler(t.server, client, true) result = await t.connHandler(t.server, client, true)
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
## TODO: implement logic to properly discriminat TCP multiaddrs if procCall Transport(t).handles(address):
true result = address.protocols.filterIt( it == multiCodec("tcp") ).len > 0

View File

@ -9,8 +9,7 @@
import sequtils import sequtils
import chronos, chronicles import chronos, chronicles
import ../peerinfo, import ../connection,
../connection,
../multiaddress, ../multiaddress,
../multicodec ../multicodec
@ -62,9 +61,10 @@ method upgrade*(t: Transport) {.base, async, gcsafe.} =
method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} = method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} =
## check if transport supportes the multiaddress ## check if transport supportes the multiaddress
# TODO: this should implement generic logic that would use the multicodec
# declared in the multicodec field and set by each individual transport # by default we skip circuit addresses to avoid
discard # having to repeat the check in every transport
address.protocols.filterIt( it == multiCodec("p2p-circuit") ).len == 0
method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} = method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} =
## get the local address of the transport in case started with 0.0.0.0:0 ## get the local address of the transport in case started with 0.0.0.0:0

View File

@ -53,7 +53,17 @@ proc initVBuffer*(): VBuffer =
## Initialize empty VBuffer. ## Initialize empty VBuffer.
result.buffer = newSeqOfCap[byte](128) result.buffer = newSeqOfCap[byte](128)
proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) = proc writePBVarint*(vb: var VBuffer, value: PBSomeUVarint) =
## Write ``value`` as variable unsigned integer.
var length = 0
var v = value and cast[type(value)](0xFFFF_FFFF_FFFF_FFFF)
vb.buffer.setLen(len(vb.buffer) + vsizeof(v))
let res = PB.putUVarint(toOpenArray(vb.buffer, vb.offset, len(vb.buffer) - 1),
length, v)
doAssert(res == VarintStatus.Success)
vb.offset += length
proc writeLPVarint*(vb: var VBuffer, value: LPSomeUVarint) =
## Write ``value`` as variable unsigned integer. ## Write ``value`` as variable unsigned integer.
var length = 0 var length = 0
# LibP2P varint supports only 63 bits. # LibP2P varint supports only 63 bits.
@ -64,6 +74,9 @@ proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
doAssert(res == VarintStatus.Success) doAssert(res == VarintStatus.Success)
vb.offset += length vb.offset += length
proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
writeLPVarint(vb, value)
proc writeSeq*[T: byte|char](vb: var VBuffer, value: openarray[T]) = proc writeSeq*[T: byte|char](vb: var VBuffer, value: openarray[T]) =
## Write array ``value`` to buffer ``vb``, value will be prefixed with ## Write array ``value`` to buffer ``vb``, value will be prefixed with
## varint length of the array. ## varint length of the array.

View File

@ -1,4 +1,4 @@
import unittest, deques, sequtils, strformat import unittest, strformat
import chronos import chronos
import ../libp2p/stream/bufferstream import ../libp2p/stream/bufferstream
@ -220,7 +220,6 @@ suite "BufferStream":
test "reads should happen in order": test "reads should happen in order":
proc testWritePtr(): Future[bool] {.async.} = proc testWritePtr(): Future[bool] {.async.} =
var count = 1
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let buff = newBufferStream(writeHandler, 10) let buff = newBufferStream(writeHandler, 10)
check buff.len == 0 check buff.len == 0
@ -245,3 +244,199 @@ suite "BufferStream":
check: check:
waitFor(testWritePtr()) == true waitFor(testWritePtr()) == true
test "pipe two streams without the `pipe` or `|` helpers":
proc pipeTest(): Future[bool] {.async.} =
proc writeHandler1(data: seq[byte]) {.async, gcsafe.}
proc writeHandler2(data: seq[byte]) {.async, gcsafe.}
var buf1 = newBufferStream(writeHandler1)
var buf2 = newBufferStream(writeHandler2)
proc writeHandler1(data: seq[byte]) {.async, gcsafe.} =
var msg = cast[string](data)
check msg == "Hello!"
await buf2.pushTo(data)
proc writeHandler2(data: seq[byte]) {.async, gcsafe.} =
var msg = cast[string](data)
check msg == "Hello!"
await buf1.pushTo(data)
var res1: seq[byte] = newSeq[byte](7)
var readFut1 = buf1.readExactly(addr res1[0], 7)
var res2: seq[byte] = newSeq[byte](7)
var readFut2 = buf2.readExactly(addr res2[0], 7)
await buf1.pushTo(cast[seq[byte]]("Hello2!"))
await buf2.pushTo(cast[seq[byte]]("Hello1!"))
await allFutures(readFut1, readFut2)
check:
res1 == cast[seq[byte]]("Hello2!")
res2 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe A -> B":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = buf1.pipe(newBufferStream())
var res1: seq[byte] = newSeq[byte](7)
var readFut = buf2.readExactly(addr res1[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await readFut
check:
res1 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe A -> B and B -> A":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = newBufferStream()
buf1 = buf1.pipe(buf2).pipe(buf1)
var res1: seq[byte] = newSeq[byte](7)
var readFut1 = buf1.readExactly(addr res1[0], 7)
var res2: seq[byte] = newSeq[byte](7)
var readFut2 = buf2.readExactly(addr res2[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await buf2.write(cast[seq[byte]]("Hello2!"))
await allFutures(readFut1, readFut2)
check:
res1 == cast[seq[byte]]("Hello2!")
res2 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe A -> A (echo)":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
buf1 = buf1.pipe(buf1)
proc reader(): Future[seq[byte]] = buf1.read(6)
proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))
var writerFut = writer()
var readerFut = reader()
await writerFut
check:
(await readerFut) == cast[seq[byte]]("Hello!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe with `|` operator - A -> B":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = buf1 | newBufferStream()
var res1: seq[byte] = newSeq[byte](7)
var readFut = buf2.readExactly(addr res1[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await readFut
check:
res1 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe with `|` operator - A -> B and B -> A":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = newBufferStream()
buf1 = buf1 | buf2 | buf1
var res1: seq[byte] = newSeq[byte](7)
var readFut1 = buf1.readExactly(addr res1[0], 7)
var res2: seq[byte] = newSeq[byte](7)
var readFut2 = buf2.readExactly(addr res2[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await buf2.write(cast[seq[byte]]("Hello2!"))
await allFutures(readFut1, readFut2)
check:
res1 == cast[seq[byte]]("Hello2!")
res2 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe with `|` operator - A -> A (echo)":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
buf1 = buf1 | buf1
proc reader(): Future[seq[byte]] = buf1.read(6)
proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))
var writerFut = writer()
var readerFut = reader()
await writerFut
check:
(await readerFut) == cast[seq[byte]]("Hello!")
result = true
check:
waitFor(pipeTest()) == true
# TODO: Need to implement deadlock prevention when
# piping to self
test "pipe deadlock":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream(size = 5)
buf1 = buf1 | buf1
var count = 30000
proc reader() {.async.} =
while count > 0:
discard await buf1.read(7)
proc writer() {.async.} =
while count > 0:
await buf1.write(cast[seq[byte]]("Hello2!"))
count.dec
var writerFut = writer()
var readerFut = reader()
await allFutures(readerFut, writerFut)
result = true
check:
waitFor(pipeTest()) == true

View File

@ -274,25 +274,15 @@ suite "Mplex":
expect LPStreamClosedError: expect LPStreamClosedError:
waitFor(testClosedForWrite()) waitFor(testClosedForWrite())
test "half closed - channel should close for read": test "half closed - channel should close for read by remote":
proc testClosedForRead(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
await chann.closedByRemote()
asyncDiscard chann.read()
expect LPStreamClosedError:
waitFor(testClosedForRead())
test "half closed - channel should close for read after eof":
proc testClosedForRead(): Future[void] {.async.} = proc testClosedForRead(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
await chann.pushTo(cast[seq[byte]]("Hello!")) await chann.pushTo(cast[seq[byte]]("Hello!"))
await chann.close() await chann.closedByRemote()
let msg = await chann.read() discard await chann.read() # this should work, since there is data in the buffer
asyncDiscard chann.read() discard await chann.read() # this should throw
expect LPStreamClosedError: expect LPStreamClosedError:
waitFor(testClosedForRead()) waitFor(testClosedForRead())
@ -312,7 +302,7 @@ suite "Mplex":
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
await chann.reset() await chann.reset()
asyncDiscard chann.read() await chann.write(cast[seq[byte]]("Hello!"))
expect LPStreamClosedError: expect LPStreamClosedError:
waitFor(testResetWrite()) waitFor(testResetWrite())

View File

@ -1,4 +1,4 @@
import unittest, strutils, sequtils, sugar, strformat, options import unittest, strutils, sequtils, strformat, options
import chronos import chronos
import ../libp2p/connection, import ../libp2p/connection,
../libp2p/multistream, ../libp2p/multistream,
@ -51,7 +51,8 @@ method write*(s: TestSelectStream, msg: seq[byte], msglen = -1)
method write*(s: TestSelectStream, msg: string, msglen = -1) method write*(s: TestSelectStream, msg: string, msglen = -1)
{.async, gcsafe.} = discard {.async, gcsafe.} = discard
method close(s: TestSelectStream) {.async, gcsafe.} = s.closed = true method close(s: TestSelectStream) {.async, gcsafe.} =
s.isClosed = true
proc newTestSelectStream(): TestSelectStream = proc newTestSelectStream(): TestSelectStream =
new result new result
@ -97,7 +98,8 @@ method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
method write*(s: TestLsStream, msg: string, msglen = -1) method write*(s: TestLsStream, msg: string, msglen = -1)
{.async, gcsafe.} = discard {.async, gcsafe.} = discard
method close(s: TestLsStream) {.async, gcsafe.} = s.closed = true method close(s: TestLsStream) {.async, gcsafe.} =
s.isClosed = true
proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} = proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} =
new result new result
@ -143,7 +145,8 @@ method write*(s: TestNaStream, msg: string, msglen = -1) {.async, gcsafe.} =
if s.step == 4: if s.step == 4:
await s.na(msg) await s.na(msg)
method close(s: TestNaStream) {.async, gcsafe.} = s.closed = true method close(s: TestNaStream) {.async, gcsafe.} =
s.isClosed = true
proc newTestNaStream(na: NaHandler): TestNaStream = proc newTestNaStream(na: NaHandler): TestNaStream =
new result new result

View File

@ -2,5 +2,11 @@ import unittest
import testvarint, testbase32, testbase58, testbase64 import testvarint, testbase32, testbase58, testbase64
import testrsa, testecnist, tested25519, testsecp256k1, testcrypto import testrsa, testecnist, tested25519, testsecp256k1, testcrypto
import testmultibase, testmultihash, testmultiaddress, testcid, testpeer import testmultibase, testmultihash, testmultiaddress, testcid, testpeer
import testtransport, testmultistream, testbufferstream,
testmplex, testidentify, testswitch, testpubsub import testtransport,
testmultistream,
testbufferstream,
testidentify,
testswitch,
testpubsub,
testmplex

View File

@ -1,5 +1,5 @@
import unittest, tables, options import unittest, tables, options
import chronos, chronicles import chronos
import ../libp2p/[switch, import ../libp2p/[switch,
multistream, multistream,
protocols/identify, protocols/identify,
@ -36,7 +36,7 @@ method init(p: TestProto) {.gcsafe.} =
suite "Switch": suite "Switch":
test "e2e use switch": test "e2e use switch":
proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) {.gcsafe.}=
let seckey = PrivateKey.random(RSA) let seckey = PrivateKey.random(RSA)
var peerInfo: PeerInfo var peerInfo: PeerInfo
peerInfo.peerId = some(PeerID.init(seckey)) peerInfo.peerId = some(PeerID.init(seckey))
@ -50,7 +50,11 @@ suite "Switch":
let transports = @[Transport(newTransport(TcpTransport))] let transports = @[Transport(newTransport(TcpTransport))]
let muxers = [(MplexCodec, mplexProvider)].toTable() let muxers = [(MplexCodec, mplexProvider)].toTable()
let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable() let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable()
let switch = newSwitch(peerInfo, transports, identify, muxers, secureManagers) let switch = newSwitch(peerInfo,
transports,
identify,
muxers,
secureManagers)
result = (switch, peerInfo) result = (switch, peerInfo)
proc testSwitch(): Future[bool] {.async, gcsafe.} = proc testSwitch(): Future[bool] {.async, gcsafe.} =