Feat/conn cleanup (#41)

Backporting proper connection cleanup from #36 to align with latest chronos changes.

* add close event

* use proper varint encoding

* add proper channel cleanup in mplex

* add connection cleanup in secio

* tidy up

* add dollar operator

* fix tests

* don't close connections prematurely

* handle closing streams properly

* misc

* implement address filtering logic

* adding pipe tests

* don't use gcsafe if not needed

* misc

* proper connection cleanup and stream muxing

* re-enable pubsub tests
This commit is contained in:
Dmitriy Ryajov 2019-12-03 22:44:54 -06:00 committed by GitHub
parent 1df16bdbce
commit 903e79ede1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 765 additions and 294 deletions

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import chronos, options, chronicles
import chronos, chronicles
import peerinfo,
multiaddress,
stream/lpstream,
@ -26,15 +26,28 @@ type
InvalidVarintException = object of LPStreamError
proc newInvalidVarintException*(): ref InvalidVarintException =
result = newException(InvalidVarintException, "unable to prase varint")
newException(InvalidVarintException, "unable to prase varint")
proc newConnection*(stream: LPStream): Connection =
## create a new Connection for the specified async reader/writer
new result
result.stream = stream
result.closeEvent = newAsyncEvent()
# bind stream's close event to connection's close
# to ensure correct close propagation
let this = result
if not isNil(result.stream.closeEvent):
result.stream.closeEvent.wait().
addCallback(
proc (udata: pointer) =
if not this.closed:
trace "closing this connection because wrapped stream closed"
asyncCheck this.close()
)
method read*(s: Connection, n = -1): Future[seq[byte]] {.gcsafe.} =
result = s.stream.read(n)
s.stream.read(n)
method readExactly*(s: Connection,
pbytes: pointer,
@ -44,13 +57,13 @@ method readExactly*(s: Connection,
method readLine*(s: Connection,
limit = 0,
sep = "\r\n"):
sep = "\r\n"):
Future[string] {.gcsafe.} =
s.stream.readLine(limit, sep)
method readOnce*(s: Connection,
pbytes: pointer,
nbytes: int):
nbytes: int):
Future[int] {.gcsafe.} =
s.stream.readOnce(pbytes, nbytes)
@ -61,15 +74,15 @@ method readUntil*(s: Connection,
Future[int] {.gcsafe.} =
s.stream.readUntil(pbytes, nbytes, sep)
method write*(s: Connection,
pbytes: pointer,
nbytes: int):
method write*(s: Connection,
pbytes: pointer,
nbytes: int):
Future[void] {.gcsafe.} =
s.stream.write(pbytes, nbytes)
method write*(s: Connection,
msg: string,
msglen = -1):
method write*(s: Connection,
msg: string,
msglen = -1):
Future[void] {.gcsafe.} =
s.stream.write(msg, msglen)
@ -79,9 +92,20 @@ method write*(s: Connection,
Future[void] {.gcsafe.} =
s.stream.write(msg, msglen)
method closed*(s: Connection): bool =
if isNil(s.stream):
return false
result = s.stream.closed
method close*(s: Connection) {.async, gcsafe.} =
await s.stream.close()
s.closed = true
trace "closing connection"
if not s.closed:
if not isNil(s.stream) and not s.stream.closed:
await s.stream.close()
s.closeEvent.fire()
s.isClosed = true
trace "connection closed", closed = s.closed
proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
## read lenght prefixed msg
@ -100,21 +124,23 @@ proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
raise newInvalidVarintException()
result.setLen(size)
if size > 0.uint:
trace "reading exact bytes from stream", size = size
await s.readExactly(addr result[0], int(size))
except LPStreamIncompleteError, LPStreamReadError:
trace "remote connection closed", exc = getCurrentExceptionMsg()
except LPStreamIncompleteError as exc:
trace "remote connection ended unexpectedly", exc = exc.msg
except LPStreamReadError as exc:
trace "couldn't read from stream", exc = exc.msg
proc writeLp*(s: Connection, msg: string | seq[byte]): Future[void] {.gcsafe.} =
## write lenght prefixed
var buf = initVBuffer()
buf.writeSeq(msg)
buf.finish()
result = s.write(buf.buffer)
s.write(buf.buffer)
method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} =
## get resolved multiaddresses for the connection
result = c.observedAddrs
proc `$`*(conn: Connection): string =
if conn.peerInfo.peerId.isSome:
result = $(conn.peerInfo.peerId.get())
result = $(conn.peerInfo)

View File

@ -855,7 +855,7 @@ proc connect*(api: DaemonAPI, peer: PeerID,
timeout))
pb.withMessage() do:
discard
finally:
except:
await api.closeConnection(transp)
proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} =

View File

@ -7,12 +7,12 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import sequtils, strutils, strformat
import strutils
import chronos, chronicles
import connection,
varint,
vbuffer,
protocols/protocol
protocols/protocol,
stream/lpstream
logScope:
topic = "Multistream"
@ -56,16 +56,16 @@ proc select*(m: MultisteamSelect,
trace "selecting proto", proto = proto
await conn.writeLp((proto[0] & "\n")) # select proto
result = cast[string](await conn.readLp()) # read ms header
result = cast[string]((await conn.readLp())) # read ms header
result.removeSuffix("\n")
if result != Codec:
trace "handshake failed", codec = result
trace "handshake failed", codec = result.toHex()
return ""
if proto.len() == 0: # no protocols, must be a handshake call
return
result = cast[string](await conn.readLp()) # read the first proto
result = cast[string]((await conn.readLp())) # read the first proto
trace "reading first requested proto"
result.removeSuffix("\n")
if result == proto[0]:
@ -76,7 +76,7 @@ proc select*(m: MultisteamSelect,
trace "selecting one of several protos"
for p in proto[1..<proto.len()]:
await conn.writeLp((p & "\n")) # select proto
result = cast[string](await conn.readLp()) # read the first proto
result = cast[string]((await conn.readLp())) # read the first proto
result.removeSuffix("\n")
if result == p:
trace "selected protocol", protocol = result
@ -102,7 +102,7 @@ proc list*(m: MultisteamSelect,
await conn.write(m.ls) # send ls
var list = newSeq[string]()
let ms = cast[string](await conn.readLp())
let ms = cast[string]((await conn.readLp()))
for s in ms.split("\n"):
if s.len() > 0:
list.add(s)
@ -111,8 +111,10 @@ proc list*(m: MultisteamSelect,
proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
trace "handle: starting multistream handling"
while not conn.closed:
var ms = cast[string](await conn.readLp())
try:
while not conn.closed:
await sleepAsync(1.millis)
var ms = cast[string]((await conn.readLp()))
ms.removeSuffix("\n")
trace "handle: got request for ", ms
@ -142,11 +144,15 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
try:
await h.protocol.handler(conn, ms)
return
except Exception as exc:
warn "exception while handling ", msg = exc.msg
except CatchableError as exc:
warn "exception while handling", msg = exc.msg
return
warn "no handlers for ", protocol = ms
await conn.write(m.na)
except CatchableError as exc:
trace "exception occured", exc = exc.msg
finally:
trace "leaving multistream loop"
proc addHandler*[T: LPProtocol](m: MultisteamSelect,
codec: string,

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import chronos, options, sequtils, strformat
import chronos, options
import nimcrypto/utils, chronicles
import types,
../../connection,
@ -29,31 +29,33 @@ proc readMplexVarint(conn: Connection): Future[Option[uint]] {.async, gcsafe.} =
varint: uint
length: int
res: VarintStatus
var buffer = newSeq[byte](10)
buffer = newSeq[byte](10)
result = none(uint)
try:
for i in 0..<len(buffer):
await conn.readExactly(addr buffer[i], 1)
res = LP.getUVarint(buffer.toOpenArray(0, i), length, varint)
if res == VarintStatus.Success:
return some(varint)
if not conn.closed:
await conn.readExactly(addr buffer[i], 1)
res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint)
if res == VarintStatus.Success:
return some(varint)
if res != VarintStatus.Success:
raise newInvalidVarintException()
except LPStreamIncompleteError:
trace "unable to read varint", exc = getCurrentExceptionMsg()
except LPStreamIncompleteError as exc:
trace "unable to read varint", exc = exc.msg
proc readMsg*(conn: Connection): Future[Option[Msg]] {.async, gcsafe.} =
let headerVarint = await conn.readMplexVarint()
if headerVarint.isNone:
return
trace "readMsg: read header varint ", varint = headerVarint
trace "read header varint", varint = headerVarint
let dataLenVarint = await conn.readMplexVarint()
var data: seq[byte]
if dataLenVarint.isSome and dataLenVarint.get() > 0.uint:
trace "readMsg: read size varint ", varint = dataLenVarint
data = await conn.read(dataLenVarint.get().int)
trace "read size varint", varint = dataLenVarint
let header = headerVarint.get()
result = some((header shr 3, MessageType(header and 0x7), data))
@ -64,11 +66,13 @@ proc writeMsg*(conn: Connection,
data: seq[byte] = @[]) {.async, gcsafe.} =
## write lenght prefixed
var buf = initVBuffer()
let header = (id shl 3 or ord(msgType).uint)
buf.writeVarint(id shl 3 or ord(msgType).uint)
buf.writeVarint(data.len().uint) # size should be always sent
buf.writePBVarint(id shl 3 or ord(msgType).uint)
buf.writePBVarint(data.len().uint) # size should be always sent
buf.finish()
await conn.write(buf.buffer & data)
try:
await conn.write(buf.buffer & data)
except LPStreamIncompleteError as exc:
trace "unable to send message", exc = exc.msg
proc writeMsg*(conn: Connection,
id: uint,

View File

@ -7,7 +7,6 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import strformat
import chronos, chronicles
import types,
coder,
@ -52,99 +51,110 @@ proc newChannel*(id: uint,
result.asyncLock = newAsyncLock()
let chan = result
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
proc writeHandler(data: seq[byte]): Future[void] {.async.} =
# writes should happen in sequence
await chan.asyncLock.acquire()
trace "writeHandler: sending data ", data = data.toHex(), id = chan.id
trace "sending data ", data = data.toHex(),
id = chan.id,
initiator = chan.initiator
await conn.writeMsg(chan.id, chan.msgCode, data) # write header
chan.asyncLock.release()
result.initBufferStream(writeHandler, size)
proc closeMessage(s: LPChannel) {.async, gcsafe.} =
proc closeMessage(s: LPChannel) {.async.} =
await s.conn.writeMsg(s.id, s.closeCode) # write header
proc closed*(s: LPChannel): bool =
s.closedLocal and s.closedLocal
proc closedByRemote*(s: LPChannel) {.async.} =
s.closedRemote = true
proc cleanUp*(s: LPChannel): Future[void] =
# method which calls the underlying buffer's `close`
# method used instead of `close` since it's overloaded to
# simulate half-closed streams
result = procCall close(BufferStream(s))
proc open*(s: LPChannel): Future[void] =
s.conn.writeMsg(s.id, MessageType.New, s.name)
method close*(s: LPChannel) {.async, gcsafe.} =
s.closedLocal = true
await s.closeMessage()
proc resetMessage(s: LPChannel) {.async, gcsafe.} =
proc resetMessage(s: LPChannel) {.async.} =
await s.conn.writeMsg(s.id, s.resetCode)
proc resetByRemote*(s: LPChannel) {.async, gcsafe.} =
proc resetByRemote*(s: LPChannel) {.async.} =
await allFutures(s.close(), s.closedByRemote())
s.isReset = true
proc reset*(s: LPChannel) {.async.} =
await allFutures(s.resetMessage(), s.resetByRemote())
proc isReadEof(s: LPChannel): bool =
bool((s.closedRemote or s.closedLocal) and s.len() < 1)
method closed*(s: LPChannel): bool =
result = s.closedRemote and s.len == 0
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] {.gcsafe.} =
if s.closedRemote:
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] =
if s.closedRemote or s.isReset:
raise newLPStreamClosedError()
trace "pushing data to channel", data = data.toHex(),
id = s.id,
initiator = s.initiator
result = procCall pushTo(BufferStream(s), data)
method read*(s: LPChannel, n = -1): Future[seq[byte]] {.gcsafe.} =
if s.isReadEof():
method read*(s: LPChannel, n = -1): Future[seq[byte]] =
if s.closed or s.isReset:
raise newLPStreamClosedError()
result = procCall read(BufferStream(s), n)
method readExactly*(s: LPChannel,
pbytes: pointer,
nbytes: int):
Future[void] {.gcsafe.} =
if s.isReadEof():
method readExactly*(s: LPChannel,
pbytes: pointer,
nbytes: int):
Future[void] =
if s.closed or s.isReset:
raise newLPStreamClosedError()
result = procCall readExactly(BufferStream(s), pbytes, nbytes)
method readLine*(s: LPChannel,
limit = 0,
sep = "\r\n"):
Future[string] {.gcsafe.} =
if s.isReadEof():
Future[string] =
if s.closed or s.isReset:
raise newLPStreamClosedError()
result = procCall readLine(BufferStream(s), limit, sep)
method readOnce*(s: LPChannel,
pbytes: pointer,
nbytes: int):
Future[int] {.gcsafe.} =
if s.isReadEof():
Future[int] =
if s.closed or s.isReset:
raise newLPStreamClosedError()
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
method readUntil*(s: LPChannel,
pbytes: pointer, nbytes: int,
sep: seq[byte]):
Future[int] {.gcsafe.} =
if s.isReadEof():
Future[int] =
if s.closed or s.isReset:
raise newLPStreamClosedError()
result = procCall readOnce(BufferStream(s), pbytes, nbytes)
method write*(s: LPChannel,
pbytes: pointer,
nbytes: int): Future[void] {.gcsafe.} =
if s.closedLocal:
nbytes: int): Future[void] =
if s.closedLocal or s.isReset:
raise newLPStreamClosedError()
result = procCall write(BufferStream(s), pbytes, nbytes)
method write*(s: LPChannel, msg: string, msglen = -1) {.async, gcsafe.} =
if s.closedLocal:
method write*(s: LPChannel, msg: string, msglen = -1) {.async.} =
if s.closedLocal or s.isReset:
raise newLPStreamClosedError()
result = procCall write(BufferStream(s), msg, msglen)
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
if s.closedLocal:
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} =
if s.closedLocal or s.isReset:
raise newLPStreamClosedError()
result = procCall write(BufferStream(s), msg, msglen)

View File

@ -11,16 +11,14 @@
## Timeouts and message limits are still missing
## they need to be added ASAP
import tables, sequtils, options, strformat
import tables, sequtils, options
import chronos, chronicles
import coder, types, lpchannel,
../muxer,
../../varint,
import ../muxer,
../../connection,
../../vbuffer,
../../protocols/protocol,
../../stream/bufferstream,
../../stream/lpstream
../../stream/lpstream,
coder,
types,
lpchannel
logScope:
topic = "Mplex"
@ -34,9 +32,11 @@ type
proc getChannelList(m: Mplex, initiator: bool): var Table[uint, LPChannel] =
if initiator:
result = m.remote
else:
trace "picking local channels", initiator = initiator
result = m.local
else:
trace "picking remote channels", initiator = initiator
result = m.remote
proc newStreamInternal*(m: Mplex,
initiator: bool = true,
@ -45,17 +45,28 @@ proc newStreamInternal*(m: Mplex,
Future[LPChannel] {.async, gcsafe.} =
## create new channel/stream
let id = if initiator: m.currentId.inc(); m.currentId else: chanId
trace "creating new channel", channelId = id, initiator = initiator
result = newChannel(id, m.connection, initiator, name)
m.getChannelList(initiator)[id] = result
proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async, inline.} =
## call the channel's `close` to signal the
## remote that the channel is closing
if not isNil(chann) and not chann.closed:
await chann.close()
await chann.cleanUp()
m.getChannelList(initiator).del(chann.id)
trace "cleaned up channel", id = chann.id
method handle*(m: Mplex) {.async, gcsafe.} =
trace "starting mplex main loop"
try:
while not m.connection.closed:
trace "waiting for data"
let msg = await m.connection.readMsg()
if msg.isNone:
# TODO: allow poll with timeout to avoid using `sleepAsync`
await sleepAsync(10.millis)
await sleepAsync(1.millis)
continue
let (id, msgType, data) = msg.get()
@ -63,8 +74,11 @@ method handle*(m: Mplex) {.async, gcsafe.} =
var channel: LPChannel
if MessageType(msgType) != MessageType.New:
let channels = m.getChannelList(initiator)
if not channels.contains(id):
trace "handle: Channel with id and msg type ", id = id, msg = msgType
if id notin channels:
trace "Channel not found, skipping", id = id,
initiator = initiator,
msg = msgType
await sleepAsync(1.millis)
continue
channel = channels[id]
@ -72,36 +86,44 @@ method handle*(m: Mplex) {.async, gcsafe.} =
of MessageType.New:
let name = cast[string](data)
channel = await m.newStreamInternal(false, id, name)
trace "handle: created channel ", id = id, name = name
trace "created channel", id = id, name = name, inititator = true
if not isNil(m.streamHandler):
let stream = newConnection(channel)
stream.peerInfo = m.connection.peerInfo
let handlerFut = m.streamHandler(stream)
# channel cleanup routine
proc cleanUpChan(udata: pointer) {.gcsafe.} =
if handlerFut.finished:
channel.close().addCallback(
proc(udata: pointer) =
channel.cleanUp()
.addCallback(proc(udata: pointer) =
trace "handle: cleaned up channel ", id = id))
handlerFut.addCallback(cleanUpChan)
# cleanup channel once handler is finished
# stream.closeEvent.wait().addCallback(
# proc(udata: pointer) =
# asyncCheck cleanupChann(m, channel, initiator))
asyncCheck m.streamHandler(stream)
continue
of MessageType.MsgIn, MessageType.MsgOut:
trace "handle: pushing data to channel ", id = id, msgType = msgType
trace "pushing data to channel", id = id,
initiator = initiator,
msgType = msgType
await channel.pushTo(data)
of MessageType.CloseIn, MessageType.CloseOut:
trace "handle: closing channel ", id = id, msgType = msgType
trace "closing channel", id = id,
initiator = initiator,
msgType = msgType
await channel.closedByRemote()
m.getChannelList(initiator).del(id)
of MessageType.ResetIn, MessageType.ResetOut:
trace "handle: resetting channel ", id = id
trace "resetting channel", id = id,
initiator = initiator,
msgType = msgType
await channel.resetByRemote()
m.getChannelList(initiator).del(id)
break
except:
error "exception occurred", exception = getCurrentExceptionMsg()
except CatchableError as exc:
trace "exception occurred", exception = exc.msg
finally:
trace "stopping mplex main loop"
await m.connection.close()
proc newMplex*(conn: Connection,
@ -112,13 +134,20 @@ proc newMplex*(conn: Connection,
result.remote = initTable[uint, LPChannel]()
result.local = initTable[uint, LPChannel]()
let m = result
conn.closeEvent.wait().addCallback(
proc(udata: pointer) =
asyncCheck m.close()
)
method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} =
let channel = await m.newStreamInternal()
await m.connection.writeMsg(channel.id, MessageType.New, name)
# TODO: open the channel (this should be lazy)
await channel.open()
result = newConnection(channel)
result.peerInfo = m.connection.peerInfo
method close*(m: Mplex) {.async, gcsafe.} =
await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.close())),
allFutures(toSeq(m.local.values).mapIt(it.close()))])
m.connection.reset()
trace "closing mplex muxer"
await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())),
allFutures(toSeq(m.local.values).mapIt(it.reset()))])

View File

@ -8,7 +8,6 @@
## those terms.
import chronos
import ../../connection
const MaxMsgSize* = 1 shl 20 # 1mb
const MaxChannels* = 1000

View File

@ -10,7 +10,27 @@
import options
import peer, multiaddress
type PeerInfo* = object of RootObj
peerId*: Option[PeerID]
addrs*: seq[MultiAddress]
protocols*: seq[string]
type
PeerInfo* = object of RootObj
peerId*: Option[PeerID]
addrs*: seq[MultiAddress]
protocols*: seq[string]
proc id*(p: PeerInfo): string =
if p.peerId.isSome:
result = p.peerId.get().pretty
proc `$`*(p: PeerInfo): string =
if p.peerId.isSome:
result.add("PeerID: ")
result.add(p.id & "\n")
if p.addrs.len > 0:
result.add("Peer Addrs: ")
for a in p.addrs:
result.add($a & "\n")
if p.protocols.len > 0:
result.add("Protocols: ")
for proto in p.protocols:
result.add(proto & "\n")

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import options, strformat
import options
import chronos, chronicles
import ../protobuf/minprotobuf,
../peerinfo,
@ -115,14 +115,14 @@ method init*(p: Identify) =
trace "handling identify request"
var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
await conn.writeLp(pb.buffer)
# await conn.close() #TODO: investigate why this breaks
p.handler = handle
p.codec = IdentifyCodec
proc identify*(p: Identify,
conn: Connection,
remotePeerInfo: PeerInfo):
Future[IdentifyInfo] {.async.} =
remotePeerInfo: PeerInfo): Future[IdentifyInfo] {.async, gcsafe.} =
var message = await conn.readLp()
if len(message) == 0:
trace "identify: Invalid or empty message received!"
@ -139,7 +139,7 @@ proc identify*(p: Identify,
if peer != remotePeerInfo.peerId.get():
trace "Peer ids don't match",
remote = peer.pretty(),
local = remotePeerInfo.peerId.get().pretty()
local = remotePeerInfo.id
raise newException(IdentityNoMatchError,
"Peer ids don't match")
@ -149,5 +149,4 @@ proc identify*(p: Identify,
proc push*(p: Identify, conn: Connection) {.async.} =
await conn.write(IdentifyPushCodec)
var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
let length = pb.getLen()
await conn.writeLp(pb.buffer)

View File

@ -8,9 +8,7 @@
## those terms.
import chronos
import ../connection,
../peerinfo,
../multiaddress
import ../connection
type
LPProtoHandler* = proc (conn: Connection,

View File

@ -14,6 +14,7 @@ import rpcmsg,
../../peer,
../../peerinfo,
../../connection,
../../stream/lpstream,
../../crypto/crypto,
../../protobuf/minprotobuf
@ -45,7 +46,7 @@ proc handle*(p: PubSubPeer) {.async, gcsafe.} =
trace "Decoded msg from peer", peer = p.id, msg = msg
await p.handler(p, @[msg])
except:
error "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
trace "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
finally:
trace "closing connection to pubsub peer", peer = p.id
await p.conn.close()

View File

@ -8,8 +8,7 @@
## those terms.
import chronos
import secure,
../../connection
import secure, ../../connection
const PlainTextCodec* = "/plaintext/1.0.0"

View File

@ -6,10 +6,12 @@
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import options
import chronos, chronicles
import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode]
import secure,
../../connection,
../../stream/lpstream,
../../crypto/crypto,
../../crypto/ecnist,
../../protobuf/minprotobuf,
@ -60,7 +62,6 @@ type
ctxsha1: HMAC[sha1]
SecureConnection* = ref object of Connection
conn*: Connection
writerMac: SecureMac
readerMac: SecureMac
writerCoder: SecureCipher
@ -176,13 +177,13 @@ proc readMessage*(sconn: SecureConnection): Future[seq[byte]] {.async.} =
## Read message from channel secure connection ``sconn``.
try:
var buf = newSeq[byte](4)
await sconn.conn.readExactly(addr buf[0], 4)
await sconn.readExactly(addr buf[0], 4)
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
(int(buf[2]) shl 8) or (int(buf[3]))
trace "Recieved message header", header = toHex(buf), length = length
if length <= SecioMaxMessageSize:
buf.setLen(length)
await sconn.conn.readExactly(addr buf[0], length)
await sconn.readExactly(addr buf[0], length)
trace "Received message body", length = length,
buffer = toHex(buf)
if sconn.macCheckAndDecode(buf):
@ -213,21 +214,27 @@ proc writeMessage*(sconn: SecureConnection, message: seq[byte]) {.async.} =
msg[3] = byte(length and 0xFF)
trace "Writing message", message = toHex(msg)
try:
await sconn.conn.write(msg)
await sconn.write(msg)
except AsyncStreamWriteError:
trace "Could not write to connection"
proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
proc newSecureConnection*(conn: Connection,
hash: string,
cipher: string,
secrets: Secret,
order: int): SecureConnection =
order: int,
peerId: PeerID): SecureConnection =
## Create new secure connection, using specified hash algorithm ``hash``,
## cipher algorithm ``cipher``, stretched keys ``secrets`` and order
## ``order``.
new result
result.stream = conn
result.closeEvent = newAsyncEvent()
let i0 = if order < 0: 1 else: 0
let i1 = if order < 0: 0 else: 1
result.conn = conn
trace "Writer credentials", mackey = toHex(secrets.macOpenArray(i0)),
enckey = toHex(secrets.keyOpenArray(i0)),
iv = toHex(secrets.ivOpenArray(i0))
@ -241,6 +248,8 @@ proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
result.readerCoder.init(cipher, secrets.keyOpenArray(i1),
secrets.ivOpenArray(i1))
result.peerInfo.peerId = some(peerId)
proc transactMessage(conn: Connection,
msg: seq[byte]): Future[seq[byte]] {.async.} =
var buf = newSeq[byte](4)
@ -281,7 +290,6 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
remoteHashes: string
remotePeerId: PeerID
localPeerId: PeerID
ekey: PrivateKey
localBytesPubkey = s.localPublicKey.getBytes()
if randomBytes(localNonce) != SecioNonceSize:
@ -388,7 +396,8 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
# Perform Nonce exchange over encrypted channel.
result = newSecureConnection(conn, hash, cipher, keys, order)
result = newSecureConnection(conn, hash, cipher, keys, order, remotePeerId)
await result.writeMessage(remoteNonce)
var res = await result.readMessage()
@ -400,17 +409,21 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
trace "Secure handshake succeeded"
proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} =
while not sconn.conn.closed:
try:
try:
while not sconn.closed:
let msg = await sconn.readMessage()
await stream.pushTo(msg)
except CatchableError as exc:
trace "exception in secio", exc = exc.msg
return
finally:
trace "ending secio readLoop"
if msg.len > 0:
await stream.pushTo(msg)
# tight loop, give a chance for other
# stuff to run as well
await sleepAsync(1.millis)
except CatchableError as exc:
trace "exception occured", exc = exc.msg
finally:
trace "ending secio readLoop", isclosed = sconn.closed()
proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn)
proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
trace "sending encrypted bytes", bytes = data.toHex()
@ -419,7 +432,13 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
var stream = newBufferStream(writeHandler)
asyncCheck readLoop(sconn, stream)
var secured = newConnection(stream)
secured.peerInfo = sconn.conn.peerInfo
secured.closeEvent.wait()
.addCallback(proc(udata: pointer) =
trace "wrapped connection closed, closing upstream"
if not sconn.closed:
asyncCheck sconn.close()
)
secured.peerInfo.peerId = sconn.peerInfo.peerId
result = secured
method init(s: Secio) {.gcsafe.} =

View File

@ -8,7 +8,7 @@
## those terms.
## This module implements an asynchronous buffer stream
## which emulates physical async IO.
## which emulates physical async IO.
##
## The stream is based on the standard library's `Deque`,
## which is itself based on a ring buffer.
@ -25,12 +25,12 @@
## ordered and asynchronous. Reads are queued up in order
## and are suspended when not enough data available. This
## allows preserving backpressure while maintaining full
## asynchrony. Both writting to the internal buffer with
## asynchrony. Both writting to the internal buffer with
## ``pushTo`` as well as reading with ``read*` methods,
## will suspend until either the amount of elements in the
## buffer goes below ``maxSize`` or more data becomes available.
import deques, tables, sequtils, math
import deques, math
import chronos
import ../stream/lpstream
@ -38,33 +38,49 @@ const DefaultBufferSize* = 1024
type
# TODO: figure out how to make this generic to avoid casts
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
WriteHandler* = proc (data: seq[byte]): Future[void]
BufferStream* = ref object of LPStream
maxSize*: int # buffer's max size in bytes
readBuf: Deque[byte] # a deque is based on a ring buffer
readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here
readReqs: Deque[Future[void]] # use dequeue to fire reads in order
dataReadEvent: AsyncEvent
writeHandler*: WriteHandler
lock: AsyncLock
isPiped: bool
proc requestReadBytes(s: BufferStream): Future[void] =
AlreadyPipedError* = object of CatchableError
NotWritableError* = object of CatchableError
proc newAlreadyPipedError*(): ref Exception {.inline.} =
result = newException(AlreadyPipedError, "stream already piped")
proc newNotWritableError*(): ref Exception {.inline.} =
result = newException(NotWritableError, "stream is not writable")
proc requestReadBytes(s: BufferStream): Future[void] =
## create a future that will complete when more
## data becomes available in the read buffer
result = newFuture[void]()
s.readReqs.addLast(result)
proc initBufferStream*(s: BufferStream, handler: WriteHandler, size: int = DefaultBufferSize) =
proc initBufferStream*(s: BufferStream,
handler: WriteHandler = nil,
size: int = DefaultBufferSize) =
s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size)
s.readBuf = initDeque[byte](s.maxSize)
s.readReqs = initDeque[Future[void]]()
s.dataReadEvent = newAsyncEvent()
s.lock = newAsyncLock()
s.writeHandler = handler
s.closeEvent = newAsyncEvent()
proc newBufferStream*(handler: WriteHandler, size: int = DefaultBufferSize): BufferStream =
proc newBufferStream*(handler: WriteHandler = nil,
size: int = DefaultBufferSize): BufferStream =
new result
result.initBufferStream(handler, size)
proc popFirst*(s: BufferStream): byte =
proc popFirst*(s: BufferStream): byte =
result = s.readBuf.popFirst()
s.dataReadEvent.fire()
@ -78,15 +94,24 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) =
proc len*(s: BufferStream): int = s.readBuf.len
proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
## Write bytes to internal read buffer, use this to fill up the
## buffer with data.
##
## This method is async and will wait until all data has been
## written to the internal buffer; this is done so that backpressure
## is preserved.
##
await s.lock.acquire()
var index = 0
while true:
# give readers a chance free up the buffer
# it it's full.
if s.readBuf.len >= s.maxSize:
await sleepAsync(10.millis)
while index < data.len and s.readBuf.len < s.maxSize:
s.readBuf.addLast(data[index])
inc(index)
@ -94,18 +119,20 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
# resolve the next queued read request
if s.readReqs.len > 0:
s.readReqs.popFirst().complete()
if index >= data.len:
break
# if we couldn't transfer all the data to the
# internal buf wait on a read event
await s.dataReadEvent.wait()
s.lock.release()
method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async.} =
## Read all bytes (n <= 0) or exactly `n` bytes from buffer
##
## This procedure allocates buffer seq[byte] and return it as result.
##
var size = if n > 0: n else: s.readBuf.len()
var index = 0
while index < size:
@ -116,25 +143,26 @@ method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
if index < size:
await s.requestReadBytes()
method readExactly*(s: BufferStream,
pbytes: pointer,
nbytes: int):
Future[void] {.async, gcsafe.} =
method readExactly*(s: BufferStream,
pbytes: pointer,
nbytes: int):
Future[void] {.async.} =
## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store
## it to ``pbytes``.
##
## If EOF is received and ``nbytes`` is not yet read, the procedure
## will raise ``LPStreamIncompleteError``.
let buff = await s.read(nbytes)
##
var buff = await s.read(nbytes)
if nbytes > buff.len():
raise newLPStreamIncompleteError()
copyMem(pbytes, unsafeAddr buff[0], nbytes)
copyMem(pbytes, addr buff[0], nbytes)
method readLine*(s: BufferStream,
limit = 0,
sep = "\r\n"):
Future[string] {.async, gcsafe.} =
sep = "\r\n"):
Future[string] {.async.} =
## Read one line from read-only stream ``rstream``, where ``"line"`` is a
## sequence of bytes ending with ``sep`` (default is ``"\r\n"``).
##
@ -146,6 +174,7 @@ method readLine*(s: BufferStream,
##
## If ``limit`` more then 0, then result string will be limited to ``limit``
## bytes.
##
result = ""
var lim = if limit <= 0: -1 else: limit
var state = 0
@ -170,14 +199,15 @@ method readLine*(s: BufferStream,
method readOnce*(s: BufferStream,
pbytes: pointer,
nbytes: int):
Future[int] {.async, gcsafe.} =
Future[int] {.async.} =
## Perform one read operation on read-only stream ``rstream``.
##
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
## internal buffer, otherwise it will wait until some bytes will be received.
##
if s.readBuf.len == 0:
await s.requestReadBytes()
var len = if nbytes > s.readBuf.len: s.readBuf.len else: nbytes
await s.readExactly(pbytes, len)
result = len
@ -186,7 +216,7 @@ method readUntil*(s: BufferStream,
pbytes: pointer,
nbytes: int,
sep: seq[byte]):
Future[int] {.async, gcsafe.} =
Future[int] {.async.} =
## Read data from the read-only stream ``rstream`` until separator ``sep`` is
## found.
##
@ -200,6 +230,7 @@ method readUntil*(s: BufferStream,
## will raise ``LPStreamLimitError``.
##
## Procedure returns actual number of bytes read.
##
var
dest = cast[ptr UncheckedArray[byte]](pbytes)
state = 0
@ -231,22 +262,22 @@ method readUntil*(s: BufferStream,
else:
s.shrink(datalen)
method write*(s: BufferStream,
pbytes: pointer,
nbytes: int): Future[void]
{.gcsafe.} =
method write*(s: BufferStream,
pbytes: pointer,
nbytes: int): Future[void] =
## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream
## ``rstream``.
##
## Return number of bytes actually consumed (discarded).
##
var buf: seq[byte] = newSeq[byte](nbytes)
copyMem(addr buf[0], pbytes, nbytes)
result = s.writeHandler(buf)
if not isNil(s.writeHandler):
result = s.writeHandler(buf)
method write*(s: BufferStream,
msg: string,
msglen = -1): Future[void]
{.gcsafe.} =
msglen = -1): Future[void] =
## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``.
##
## String ``sbytes`` must not be zero-length.
@ -254,14 +285,15 @@ method write*(s: BufferStream,
## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream.
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
## stream.
##
var buf = ""
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
result = s.writeHandler(cast[seq[byte]](buf))
if not isNil(s.writeHandler):
result = s.writeHandler(cast[seq[byte]](buf))
method write*(s: BufferStream,
msg: seq[byte],
msglen = -1): Future[void]
{.gcsafe.} =
msglen = -1): Future[void] =
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
## stream ``wstream``.
##
@ -270,13 +302,56 @@ method write*(s: BufferStream,
## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream.
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
## stream.
##
var buf: seq[byte]
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
result = s.writeHandler(buf)
if not isNil(s.writeHandler):
result = s.writeHandler(buf)
method close*(s: BufferStream) {.async, gcsafe.} =
proc pipe*(s: BufferStream,
target: BufferStream): BufferStream =
## pipe the write end of this stream to
## be the source of the target stream
##
## Note that this only works with the LPStream
## interface methods `read*` and `write` are
## piped.
##
if s.isPiped:
raise newAlreadyPipedError()
s.isPiped = true
let oldHandler = target.writeHandler
proc handler(data: seq[byte]) {.async, closure.} =
if not isNil(oldHandler):
await oldHandler(data)
# if we're piping to self,
# then add the data to the
# buffer directly and fire
# the read event
if s == target:
for b in data:
s.readBuf.addLast(b)
# notify main loop of available
# data
s.dataReadEvent.fire()
else:
await target.pushTo(data)
s.writeHandler = handler
result = target
proc `|`*(s: BufferStream, target: BufferStream): BufferStream =
## pipe operator to make piping less verbose
pipe(s, target)
method close*(s: BufferStream) {.async.} =
## close the stream and clear the buffer
for r in s.readReqs:
r.cancel()
s.dataReadEvent.fire()
s.readBuf.clear()
s.closed = true
s.closeEvent.fire()
s.isClosed = true

View File

@ -26,40 +26,63 @@ proc newChronosStream*(server: StreamServer,
result.client = client
result.reader = newAsyncStreamReader(client)
result.writer = newAsyncStreamWriter(client)
result.closed = false
result.closeEvent = newAsyncEvent()
method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
try:
result = await s.reader.read(n)
except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readExactly*(s: ChronosStream,
pbytes: pointer,
nbytes: int): Future[void] {.async, gcsafe.} =
nbytes: int): Future[void] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
try:
await s.reader.readExactly(pbytes, nbytes)
except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError()
except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async, gcsafe.} =
try:
result = await s.reader.readLine(limit, sep)
except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async, gcsafe.} =
try:
result = await s.reader.readOnce(pbytes, nbytes)
except AsyncStreamReadError as exc:
raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method readUntil*(s: ChronosStream,
pbytes: pointer,
nbytes: int,
sep: seq[byte]): Future[int] {.async, gcsafe.} =
sep: seq[byte]): Future[int] {.async.} =
if s.reader.atEof:
raise newLPStreamClosedError()
try:
result = await s.reader.readUntil(pbytes, nbytes, sep)
except AsyncStreamIncompleteError:
@ -68,36 +91,62 @@ method readUntil*(s: ChronosStream,
raise newLPStreamLimitError()
except LPStreamReadError as exc:
raise newLPStreamReadError(exc.par)
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async.} =
if s.writer.atEof:
raise newLPStreamClosedError()
method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async, gcsafe.} =
try:
await s.writer.write(pbytes, nbytes)
except AsyncStreamWriteError as exc:
raise newLPStreamWriteError(exc.par)
except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError()
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method write*(s: ChronosStream, msg: string, msglen = -1) {.async.} =
if s.writer.atEof:
raise newLPStreamClosedError()
method write*(s: ChronosStream, msg: string, msglen = -1) {.async, gcsafe.} =
try:
await s.writer.write(msg, msglen)
except AsyncStreamWriteError as exc:
raise newLPStreamWriteError(exc.par)
except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError()
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} =
if s.writer.atEof:
raise newLPStreamClosedError()
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
try:
await s.writer.write(msg, msglen)
except AsyncStreamWriteError as exc:
raise newLPStreamWriteError(exc.par)
except AsyncStreamIncompleteError:
raise newLPStreamIncompleteError()
except AsyncStreamIncorrectError as exc:
raise newLPStreamIncorrectError(exc.msg)
method close*(s: ChronosStream) {.async, gcsafe.} =
method closed*(s: ChronosStream): bool {.inline.} =
# TODO: we might only need to check for reader's EOF
result = s.reader.atEof()
method close*(s: ChronosStream) {.async.} =
if not s.closed:
trace "shutting down server", address = $s.client.remoteAddress()
await s.writer.finish()
await s.writer.closeWait()
await s.reader.closeWait()
await s.client.closeWait()
s.closed = true
trace "shutting chronos stream", address = $s.client.remoteAddress()
if not s.writer.closed():
await s.writer.closeWait()
if not s.reader.closed():
await s.reader.closeWait()
if not s.client.closed():
await s.client.closeWait()
s.closeEvent.fire()

View File

@ -11,7 +11,8 @@ import chronos
type
LPStream* = ref object of RootObj
closed*: bool
isClosed*: bool
closeEvent*: AsyncEvent
LPStreamError* = object of CatchableError
LPStreamIncompleteError* = object of LPStreamError
@ -47,40 +48,43 @@ proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
proc newLPStreamClosedError*(): ref Exception {.inline.} =
result = newException(LPStreamClosedError, "Stream closed!")
method closed*(s: LPStream): bool {.base, inline.} =
s.isClosed
method read*(s: LPStream, n = -1): Future[seq[byte]]
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method readExactly*(s: LPStream, pbytes: pointer, nbytes: int): Future[void]
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string]
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method readOnce*(s: LPStream, pbytes: pointer, nbytes: int): Future[int]
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method readUntil*(s: LPStream,
pbytes: pointer, nbytes: int,
sep: seq[byte]): Future[int]
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method write*(s: LPStream, pbytes: pointer, nbytes: int)
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method write*(s: LPStream, msg: string, msglen = -1)
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method write*(s: LPStream, msg: seq[byte], msglen = -1)
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")
method close*(s: LPStream)
{.base, async, gcsafe.} =
{.base, async.} =
doAssert(false, "not implemented!")

View File

@ -11,13 +11,11 @@ import tables, sequtils, options, strformat
import chronos, chronicles
import connection,
transports/transport,
stream/lpstream,
multistream,
protocols/protocol,
protocols/secure/secure,
protocols/secure/plaintext, # for plain text
peerinfo,
multiaddress,
protocols/identify,
protocols/pubsub/pubsub,
muxers/muxer,
@ -26,6 +24,12 @@ import connection,
logScope:
topic = "Switch"
#TODO: General note - use a finite state machine to manage the different
# steps of connections establishing and upgrading. This makes everything
# more robust and less prone to ordering attacks - i.e. muxing can come if
# and only if the channel has been secured (i.e. if a secure manager has been
# previously provided)
type
NoPubSubException = object of CatchableError
@ -48,7 +52,6 @@ proc newNoPubSubException(): ref Exception {.inline.} =
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
## secure the incoming connection
# plaintext for now, doesn't do anything
let managers = toSeq(s.secureManagers.keys)
if managers.len == 0:
raise newException(CatchableError, "No secure managers registered!")
@ -62,20 +65,21 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
## identify the connection
result = conn.peerInfo
try:
if (await s.ms.select(conn, s.identity.codec)):
let info = await s.identity.identify(conn, conn.peerInfo)
if info.pubKey.isSome:
result.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all
trace "identify: identified remote peer", peer = result.id
if info.addrs.len > 0:
result.addrs = info.addrs
if info.protos.len > 0:
result.protocols = info.protos
trace "identify: identified remote peer ", peer = result.peerId.get().pretty
except IdentityInvalidMsgError as exc:
error "identify: invalid message", msg = exc.msg
except IdentityNoMatchError as exc:
@ -100,22 +104,23 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
muxer.streamHandler = s.streamHandler
# new stream for identify
let stream = await muxer.newStream()
var stream = await muxer.newStream()
let handlerFut = muxer.handle()
# add muxer handler cleanup proc
handlerFut.addCallback(
proc(udata: pointer = nil) {.gcsafe.} =
trace "mux: Muxer handler completed for peer ",
peer = conn.peerInfo.peerId.get().pretty
trace "muxer handler completed for peer",
peer = conn.peerInfo.id
)
# do identify first, so that we have a
# PeerInfo in case we didn't before
conn.peerInfo = await s.identify(stream)
await stream.close() # close idenity stream
trace "connection's peerInfo", peerInfo = conn.peerInfo.peerId
await stream.close() # close identify stream
trace "connection's peerInfo", peerInfo = conn.peerInfo
# store it in muxed connections if we have a peer for it
# TODO: We should make sure that this are cleaned up properly
@ -123,43 +128,42 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
# happen once secio is in place, but still something to keep
# in mind
if conn.peerInfo.peerId.isSome:
trace "adding muxer for peer", peer = conn.peerInfo.peerId.get().pretty
s.muxed[conn.peerInfo.peerId.get().pretty] = muxer
trace "adding muxer for peer", peer = conn.peerInfo.id
s.muxed[conn.peerInfo.id] = muxer
proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
if conn.peerInfo.peerId.isSome:
let id = conn.peerInfo.peerId.get().pretty
if s.muxed.contains(id):
await s.muxed[id].close
if s.connections.contains(id):
let id = conn.peerInfo.id
trace "cleaning up connection for peer", peerId = id
if id in s.muxed:
await s.muxed[id].close()
s.muxed.del(id)
if id in s.connections:
await s.connections[id].close()
s.connections.del(id)
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} =
# if there is a muxer for the connection
# use it instead to create a muxed stream
if s.muxed.contains(peerInfo.peerId.get().pretty):
trace "connection is muxed, retriving muxer and setting up a stream"
let muxer = s.muxed[peerInfo.peerId.get().pretty]
if peerInfo.id in s.muxed:
trace "connection is muxed, setting up a stream"
let muxer = s.muxed[peerInfo.id]
let conn = await muxer.newStream()
result = some(conn)
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
trace "handling connection", conn = conn
result = conn
## perform upgrade flow
if result.peerInfo.peerId.isSome:
let id = result.peerInfo.peerId.get().pretty
if s.connections.contains(id):
# if we already have a connection for this peer,
# close the incoming connection and return the
# existing one
await result.close()
return s.connections[id]
s.connections[id] = result
result = await s.secure(conn) # secure the connection
# don't mux/secure twise
if conn.peerInfo.peerId.isSome and
conn.peerInfo.id in s.muxed:
return
result = await s.secure(result) # secure the connection
await s.mux(result) # mux it if possible
s.connections[conn.peerInfo.id] = result
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
trace "upgrading incoming connection"
@ -192,42 +196,57 @@ proc dial*(s: Switch,
peer: PeerInfo,
proto: string = ""):
Future[Connection] {.async.} =
trace "dialing peer", peer = peer.peerId.get().pretty
let id = peer.id
trace "dialing peer", peer = id
for t in s.transports: # for each transport
for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it
result = await t.dial(a)
# make sure to assign the peer to the connection
result.peerInfo = peer
if id notin s.connections:
trace "dialing address", address = $a
result = await t.dial(a)
# make sure to assign the peer to the connection
result.peerInfo = peer
result = await s.upgradeOutgoing(result)
result.closeEvent.wait().addCallback(
proc(udata: pointer) =
asyncCheck s.cleanupConn(result)
)
let stream = await s.getMuxedStream(peer)
if stream.isSome:
trace "connection is muxed, return muxed stream"
result = stream.get()
if proto.len > 0 and not result.closed:
let stream = await s.getMuxedStream(peer)
if stream.isSome:
trace "connection is muxed, return muxed stream"
result = stream.get()
trace "attempting to select remote", proto = proto
trace "dial: attempting to select remote ", proto = proto
if not (await s.ms.select(result, proto)):
error "dial: Unable to select protocol: ", proto = proto
raise newException(CatchableError,
&"Unable to select protocol: {proto}")
if not (await s.ms.select(result, proto)):
error "unable to select protocol: ", proto = proto
raise newException(CatchableError,
&"unable to select protocol: {proto}")
break # don't dial more than one addr on the same transport
proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
if isNil(proto.handler):
raise newException(CatchableError,
raise newException(CatchableError,
"Protocol has to define a handle method or proc")
if proto.codec.len == 0:
raise newException(CatchableError,
raise newException(CatchableError,
"Protocol has to define a codec string")
s.ms.addHandler(proto.codec, proto)
proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
trace "starting switch"
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
try:
await s.upgradeIncoming(conn) # perform upgrade on incoming connection
except CatchableError as exc:
trace "exception occured", exc = exc.msg
finally:
await conn.close()
await s.cleanupConn(conn)
var startFuts: seq[Future[void]]
@ -237,10 +256,13 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
var server = await t.listen(a, handle)
s.peerInfo.addrs[i] = t.ma # update peer's address
startFuts.add(server)
result = startFuts # listen for incoming connections
proc stop*(s: Switch) {.async.} =
await allFutures(toSeq(s.connections.values).mapIt(it.close()))
trace "stopping switch"
await allFutures(toSeq(s.connections.values).mapIt(s.cleanupConn(it)))
await allFutures(s.transports.mapIt(it.close()))
proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
@ -253,14 +275,14 @@ proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] {
## subscribe to a pubsub topic
if s.pubSub.isNone:
raise newNoPubSubException()
result = s.pubSub.get().subscribe(topic, handler)
proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] {.gcsafe.} =
## unsubscribe from topics
if s.pubSub.isNone:
raise newNoPubSubException()
result = s.pubSub.get().unsubscribe(topics)
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[void] {.gcsafe.} =

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import chronos, chronicles
import chronos, chronicles, sequtils
import transport,
../wire,
../connection,
@ -78,5 +78,5 @@ method dial*(t: TcpTransport,
result = await t.connHandler(t.server, client, true)
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
## TODO: implement logic to properly discriminat TCP multiaddrs
true
if procCall Transport(t).handles(address):
result = address.protocols.filterIt( it == multiCodec("tcp") ).len > 0

View File

@ -9,8 +9,7 @@
import sequtils
import chronos, chronicles
import ../peerinfo,
../connection,
import ../connection,
../multiaddress,
../multicodec
@ -62,9 +61,10 @@ method upgrade*(t: Transport) {.base, async, gcsafe.} =
method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} =
## check if transport supportes the multiaddress
# TODO: this should implement generic logic that would use the multicodec
# declared in the multicodec field and set by each individual transport
discard
# by default we skip circuit addresses to avoid
# having to repeat the check in every transport
address.protocols.filterIt( it == multiCodec("p2p-circuit") ).len == 0
method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} =
## get the local address of the transport in case started with 0.0.0.0:0

View File

@ -53,7 +53,17 @@ proc initVBuffer*(): VBuffer =
## Initialize empty VBuffer.
result.buffer = newSeqOfCap[byte](128)
proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
proc writePBVarint*(vb: var VBuffer, value: PBSomeUVarint) =
## Write ``value`` as variable unsigned integer.
var length = 0
var v = value and cast[type(value)](0xFFFF_FFFF_FFFF_FFFF)
vb.buffer.setLen(len(vb.buffer) + vsizeof(v))
let res = PB.putUVarint(toOpenArray(vb.buffer, vb.offset, len(vb.buffer) - 1),
length, v)
doAssert(res == VarintStatus.Success)
vb.offset += length
proc writeLPVarint*(vb: var VBuffer, value: LPSomeUVarint) =
## Write ``value`` as variable unsigned integer.
var length = 0
# LibP2P varint supports only 63 bits.
@ -64,6 +74,9 @@ proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
doAssert(res == VarintStatus.Success)
vb.offset += length
proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
writeLPVarint(vb, value)
proc writeSeq*[T: byte|char](vb: var VBuffer, value: openarray[T]) =
## Write array ``value`` to buffer ``vb``, value will be prefixed with
## varint length of the array.

View File

@ -1,4 +1,4 @@
import unittest, deques, sequtils, strformat
import unittest, strformat
import chronos
import ../libp2p/stream/bufferstream
@ -220,7 +220,6 @@ suite "BufferStream":
test "reads should happen in order":
proc testWritePtr(): Future[bool] {.async.} =
var count = 1
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let buff = newBufferStream(writeHandler, 10)
check buff.len == 0
@ -245,3 +244,199 @@ suite "BufferStream":
check:
waitFor(testWritePtr()) == true
test "pipe two streams without the `pipe` or `|` helpers":
proc pipeTest(): Future[bool] {.async.} =
proc writeHandler1(data: seq[byte]) {.async, gcsafe.}
proc writeHandler2(data: seq[byte]) {.async, gcsafe.}
var buf1 = newBufferStream(writeHandler1)
var buf2 = newBufferStream(writeHandler2)
proc writeHandler1(data: seq[byte]) {.async, gcsafe.} =
var msg = cast[string](data)
check msg == "Hello!"
await buf2.pushTo(data)
proc writeHandler2(data: seq[byte]) {.async, gcsafe.} =
var msg = cast[string](data)
check msg == "Hello!"
await buf1.pushTo(data)
var res1: seq[byte] = newSeq[byte](7)
var readFut1 = buf1.readExactly(addr res1[0], 7)
var res2: seq[byte] = newSeq[byte](7)
var readFut2 = buf2.readExactly(addr res2[0], 7)
await buf1.pushTo(cast[seq[byte]]("Hello2!"))
await buf2.pushTo(cast[seq[byte]]("Hello1!"))
await allFutures(readFut1, readFut2)
check:
res1 == cast[seq[byte]]("Hello2!")
res2 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe A -> B":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = buf1.pipe(newBufferStream())
var res1: seq[byte] = newSeq[byte](7)
var readFut = buf2.readExactly(addr res1[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await readFut
check:
res1 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe A -> B and B -> A":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = newBufferStream()
buf1 = buf1.pipe(buf2).pipe(buf1)
var res1: seq[byte] = newSeq[byte](7)
var readFut1 = buf1.readExactly(addr res1[0], 7)
var res2: seq[byte] = newSeq[byte](7)
var readFut2 = buf2.readExactly(addr res2[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await buf2.write(cast[seq[byte]]("Hello2!"))
await allFutures(readFut1, readFut2)
check:
res1 == cast[seq[byte]]("Hello2!")
res2 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe A -> A (echo)":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
buf1 = buf1.pipe(buf1)
proc reader(): Future[seq[byte]] = buf1.read(6)
proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))
var writerFut = writer()
var readerFut = reader()
await writerFut
check:
(await readerFut) == cast[seq[byte]]("Hello!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe with `|` operator - A -> B":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = buf1 | newBufferStream()
var res1: seq[byte] = newSeq[byte](7)
var readFut = buf2.readExactly(addr res1[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await readFut
check:
res1 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe with `|` operator - A -> B and B -> A":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
var buf2 = newBufferStream()
buf1 = buf1 | buf2 | buf1
var res1: seq[byte] = newSeq[byte](7)
var readFut1 = buf1.readExactly(addr res1[0], 7)
var res2: seq[byte] = newSeq[byte](7)
var readFut2 = buf2.readExactly(addr res2[0], 7)
await buf1.write(cast[seq[byte]]("Hello1!"))
await buf2.write(cast[seq[byte]]("Hello2!"))
await allFutures(readFut1, readFut2)
check:
res1 == cast[seq[byte]]("Hello2!")
res2 == cast[seq[byte]]("Hello1!")
result = true
check:
waitFor(pipeTest()) == true
test "pipe with `|` operator - A -> A (echo)":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream()
buf1 = buf1 | buf1
proc reader(): Future[seq[byte]] = buf1.read(6)
proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))
var writerFut = writer()
var readerFut = reader()
await writerFut
check:
(await readerFut) == cast[seq[byte]]("Hello!")
result = true
check:
waitFor(pipeTest()) == true
# TODO: Need to implement deadlock prevention when
# piping to self
test "pipe deadlock":
proc pipeTest(): Future[bool] {.async.} =
var buf1 = newBufferStream(size = 5)
buf1 = buf1 | buf1
var count = 30000
proc reader() {.async.} =
while count > 0:
discard await buf1.read(7)
proc writer() {.async.} =
while count > 0:
await buf1.write(cast[seq[byte]]("Hello2!"))
count.dec
var writerFut = writer()
var readerFut = reader()
await allFutures(readerFut, writerFut)
result = true
check:
waitFor(pipeTest()) == true

View File

@ -274,25 +274,15 @@ suite "Mplex":
expect LPStreamClosedError:
waitFor(testClosedForWrite())
test "half closed - channel should close for read":
proc testClosedForRead(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
await chann.closedByRemote()
asyncDiscard chann.read()
expect LPStreamClosedError:
waitFor(testClosedForRead())
test "half closed - channel should close for read after eof":
test "half closed - channel should close for read by remote":
proc testClosedForRead(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
await chann.pushTo(cast[seq[byte]]("Hello!"))
await chann.close()
let msg = await chann.read()
asyncDiscard chann.read()
await chann.closedByRemote()
discard await chann.read() # this should work, since there is data in the buffer
discard await chann.read() # this should throw
expect LPStreamClosedError:
waitFor(testClosedForRead())
@ -312,7 +302,7 @@ suite "Mplex":
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
await chann.reset()
asyncDiscard chann.read()
await chann.write(cast[seq[byte]]("Hello!"))
expect LPStreamClosedError:
waitFor(testResetWrite())

View File

@ -1,4 +1,4 @@
import unittest, strutils, sequtils, sugar, strformat, options
import unittest, strutils, sequtils, strformat, options
import chronos
import ../libp2p/connection,
../libp2p/multistream,
@ -51,7 +51,8 @@ method write*(s: TestSelectStream, msg: seq[byte], msglen = -1)
method write*(s: TestSelectStream, msg: string, msglen = -1)
{.async, gcsafe.} = discard
method close(s: TestSelectStream) {.async, gcsafe.} = s.closed = true
method close(s: TestSelectStream) {.async, gcsafe.} =
s.isClosed = true
proc newTestSelectStream(): TestSelectStream =
new result
@ -97,7 +98,8 @@ method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
method write*(s: TestLsStream, msg: string, msglen = -1)
{.async, gcsafe.} = discard
method close(s: TestLsStream) {.async, gcsafe.} = s.closed = true
method close(s: TestLsStream) {.async, gcsafe.} =
s.isClosed = true
proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} =
new result
@ -143,7 +145,8 @@ method write*(s: TestNaStream, msg: string, msglen = -1) {.async, gcsafe.} =
if s.step == 4:
await s.na(msg)
method close(s: TestNaStream) {.async, gcsafe.} = s.closed = true
method close(s: TestNaStream) {.async, gcsafe.} =
s.isClosed = true
proc newTestNaStream(na: NaHandler): TestNaStream =
new result

View File

@ -2,5 +2,11 @@ import unittest
import testvarint, testbase32, testbase58, testbase64
import testrsa, testecnist, tested25519, testsecp256k1, testcrypto
import testmultibase, testmultihash, testmultiaddress, testcid, testpeer
import testtransport, testmultistream, testbufferstream,
testmplex, testidentify, testswitch, testpubsub
import testtransport,
testmultistream,
testbufferstream,
testidentify,
testswitch,
testpubsub,
testmplex

View File

@ -1,5 +1,5 @@
import unittest, tables, options
import chronos, chronicles
import chronos
import ../libp2p/[switch,
multistream,
protocols/identify,
@ -36,7 +36,7 @@ method init(p: TestProto) {.gcsafe.} =
suite "Switch":
test "e2e use switch":
proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) {.gcsafe.}=
let seckey = PrivateKey.random(RSA)
var peerInfo: PeerInfo
peerInfo.peerId = some(PeerID.init(seckey))
@ -50,7 +50,11 @@ suite "Switch":
let transports = @[Transport(newTransport(TcpTransport))]
let muxers = [(MplexCodec, mplexProvider)].toTable()
let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable()
let switch = newSwitch(peerInfo, transports, identify, muxers, secureManagers)
let switch = newSwitch(peerInfo,
transports,
identify,
muxers,
secureManagers)
result = (switch, peerInfo)
proc testSwitch(): Future[bool] {.async, gcsafe.} =