Feat/conn cleanup (#41)
Backporting proper connection cleanup from #36 to align with latest chronos changes.

* add close event
* use proper varint encoding
* add proper channel cleanup in mplex
* add connection cleanup in secio
* tidy up
* add dollar operator
* fix tests
* don't close connections prematurely
* handle closing streams properly
* misc
* implement address filtering logic
* adding pipe tests
* don't use gcsafe if not needed
* misc
* proper connection cleanup and stream muxing
* re-enable pubsub tests
parent 1df16bdbce
commit 903e79ede1
@@ -7,7 +7,7 @@
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

-import chronos, options, chronicles
+import chronos, chronicles
 import peerinfo,
 multiaddress,
 stream/lpstream,

@@ -26,15 +26,28 @@ type
 InvalidVarintException = object of LPStreamError

 proc newInvalidVarintException*(): ref InvalidVarintException =
-result = newException(InvalidVarintException, "unable to prase varint")
+newException(InvalidVarintException, "unable to prase varint")

 proc newConnection*(stream: LPStream): Connection =
 ## create a new Connection for the specified async reader/writer
 new result
 result.stream = stream
+result.closeEvent = newAsyncEvent()
+
+# bind stream's close event to connection's close
+# to ensure correct close propagation
+let this = result
+if not isNil(result.stream.closeEvent):
+result.stream.closeEvent.wait().
+addCallback(
+proc (udata: pointer) =
+if not this.closed:
+trace "closing this connection because wrapped stream closed"
+asyncCheck this.close()
+)

 method read*(s: Connection, n = -1): Future[seq[byte]] {.gcsafe.} =
-result = s.stream.read(n)
+s.stream.read(n)

 method readExactly*(s: Connection,
 pbytes: pointer,

@@ -44,13 +57,13 @@ method readExactly*(s: Connection,

 method readLine*(s: Connection,
 limit = 0,
 sep = "\r\n"):
 Future[string] {.gcsafe.} =
 s.stream.readLine(limit, sep)

 method readOnce*(s: Connection,
 pbytes: pointer,
 nbytes: int):
 Future[int] {.gcsafe.} =
 s.stream.readOnce(pbytes, nbytes)

@@ -61,15 +74,15 @@ method readUntil*(s: Connection,
 Future[int] {.gcsafe.} =
 s.stream.readUntil(pbytes, nbytes, sep)

 method write*(s: Connection,
 pbytes: pointer,
 nbytes: int):
 Future[void] {.gcsafe.} =
 s.stream.write(pbytes, nbytes)

 method write*(s: Connection,
 msg: string,
 msglen = -1):
 Future[void] {.gcsafe.} =
 s.stream.write(msg, msglen)

@@ -79,9 +92,20 @@ method write*(s: Connection,
 Future[void] {.gcsafe.} =
 s.stream.write(msg, msglen)

+method closed*(s: Connection): bool =
+if isNil(s.stream):
+return false
+
+result = s.stream.closed
+
 method close*(s: Connection) {.async, gcsafe.} =
-await s.stream.close()
-s.closed = true
+trace "closing connection"
+if not s.closed:
+if not isNil(s.stream) and not s.stream.closed:
+await s.stream.close()
+s.closeEvent.fire()
+s.isClosed = true
+trace "connection closed", closed = s.closed

 proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
 ## read lenght prefixed msg

@@ -100,21 +124,23 @@ proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
 raise newInvalidVarintException()
 result.setLen(size)
 if size > 0.uint:
+trace "reading exact bytes from stream", size = size
 await s.readExactly(addr result[0], int(size))
-except LPStreamIncompleteError, LPStreamReadError:
-trace "remote connection closed", exc = getCurrentExceptionMsg()
+except LPStreamIncompleteError as exc:
+trace "remote connection ended unexpectedly", exc = exc.msg
+except LPStreamReadError as exc:
+trace "couldn't read from stream", exc = exc.msg

 proc writeLp*(s: Connection, msg: string | seq[byte]): Future[void] {.gcsafe.} =
 ## write lenght prefixed
 var buf = initVBuffer()
 buf.writeSeq(msg)
 buf.finish()
-result = s.write(buf.buffer)
+s.write(buf.buffer)

 method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} =
 ## get resolved multiaddresses for the connection
 result = c.observedAddrs

 proc `$`*(conn: Connection): string =
-if conn.peerInfo.peerId.isSome:
-result = $(conn.peerInfo.peerId.get())
+result = $(conn.peerInfo)
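The hunks above wire a Connection's new closeEvent to the wrapped stream's closeEvent, so closing the inner stream closes the wrapper. The following standalone sketch is not part of the commit; it uses chronos only, with placeholder Wrapped/Wrapper types instead of LPStream/Connection, to show the same wait-on-event-then-close-via-callback pattern.

import chronos

type
  Wrapped = ref object
    closeEvent: AsyncEvent
  Wrapper = ref object
    inner: Wrapped
    closed: bool

proc close(w: Wrapper) {.async.} =
  if not w.closed:
    w.closed = true
    echo "wrapper closed because the wrapped object closed"

proc newWrapper(inner: Wrapped): Wrapper =
  # mirror newConnection above: register a callback on the inner close event
  let res = Wrapper(inner: inner)
  inner.closeEvent.wait().addCallback(
    proc(udata: pointer) {.gcsafe.} =
      if not res.closed:
        asyncCheck res.close())
  res

when isMainModule:
  let inner = Wrapped(closeEvent: newAsyncEvent())
  discard newWrapper(inner)
  inner.closeEvent.fire()          # simulate the wrapped stream closing
  waitFor sleepAsync(10.millis)    # give the close callback a chance to run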
@@ -855,7 +855,7 @@ proc connect*(api: DaemonAPI, peer: PeerID,
 timeout))
 pb.withMessage() do:
 discard
-finally:
+except:
 await api.closeConnection(transp)

 proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} =
@@ -7,12 +7,12 @@
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

-import sequtils, strutils, strformat
+import strutils
 import chronos, chronicles
 import connection,
-varint,
 vbuffer,
-protocols/protocol
+protocols/protocol,
+stream/lpstream

 logScope:
 topic = "Multistream"

@@ -56,16 +56,16 @@ proc select*(m: MultisteamSelect,
 trace "selecting proto", proto = proto
 await conn.writeLp((proto[0] & "\n")) # select proto

-result = cast[string](await conn.readLp()) # read ms header
+result = cast[string]((await conn.readLp())) # read ms header
 result.removeSuffix("\n")
 if result != Codec:
-trace "handshake failed", codec = result
+trace "handshake failed", codec = result.toHex()
 return ""

 if proto.len() == 0: # no protocols, must be a handshake call
 return

-result = cast[string](await conn.readLp()) # read the first proto
+result = cast[string]((await conn.readLp())) # read the first proto
 trace "reading first requested proto"
 result.removeSuffix("\n")
 if result == proto[0]:

@@ -76,7 +76,7 @@ proc select*(m: MultisteamSelect,
 trace "selecting one of several protos"
 for p in proto[1..<proto.len()]:
 await conn.writeLp((p & "\n")) # select proto
-result = cast[string](await conn.readLp()) # read the first proto
+result = cast[string]((await conn.readLp())) # read the first proto
 result.removeSuffix("\n")
 if result == p:
 trace "selected protocol", protocol = result

@@ -102,7 +102,7 @@ proc list*(m: MultisteamSelect,
 await conn.write(m.ls) # send ls

 var list = newSeq[string]()
-let ms = cast[string](await conn.readLp())
+let ms = cast[string]((await conn.readLp()))
 for s in ms.split("\n"):
 if s.len() > 0:
 list.add(s)

@@ -111,8 +111,10 @@ proc list*(m: MultisteamSelect,

 proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
 trace "handle: starting multistream handling"
-while not conn.closed:
-var ms = cast[string](await conn.readLp())
+try:
+while not conn.closed:
+await sleepAsync(1.millis)
+var ms = cast[string]((await conn.readLp()))
 ms.removeSuffix("\n")

 trace "handle: got request for ", ms

@@ -142,11 +144,15 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
 try:
 await h.protocol.handler(conn, ms)
 return
-except Exception as exc:
-warn "exception while handling ", msg = exc.msg
+except CatchableError as exc:
+warn "exception while handling", msg = exc.msg
 return
 warn "no handlers for ", protocol = ms
 await conn.write(m.na)
+except CatchableError as exc:
+trace "exception occured", exc = exc.msg
+finally:
+trace "leaving multistream loop"

 proc addHandler*[T: LPProtocol](m: MultisteamSelect,
 codec: string,
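Each multistream-select line above is sent with writeLp as a newline-terminated, varint-length-prefixed string. A purely illustrative sketch of one such frame on the wire, assuming the standard "/multistream/1.0.0" handshake id (the Codec constant's value is not shown in this diff); for lines shorter than 128 bytes the varint prefix is a single length byte.

let line = "/multistream/1.0.0\n"
let frame = char(line.len) & line   # 0x13 length byte followed by the 19-byte line
echo frame.len                      # 20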
@@ -7,7 +7,7 @@
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

-import chronos, options, sequtils, strformat
+import chronos, options
 import nimcrypto/utils, chronicles
 import types,
 ../../connection,

@@ -29,31 +29,33 @@ proc readMplexVarint(conn: Connection): Future[Option[uint]] {.async, gcsafe.} =
 varint: uint
 length: int
 res: VarintStatus
-var buffer = newSeq[byte](10)
+buffer = newSeq[byte](10)

 result = none(uint)
 try:
 for i in 0..<len(buffer):
-await conn.readExactly(addr buffer[i], 1)
-res = LP.getUVarint(buffer.toOpenArray(0, i), length, varint)
-if res == VarintStatus.Success:
-return some(varint)
+if not conn.closed:
+await conn.readExactly(addr buffer[i], 1)
+res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint)
+if res == VarintStatus.Success:
+return some(varint)
 if res != VarintStatus.Success:
 raise newInvalidVarintException()
-except LPStreamIncompleteError:
-trace "unable to read varint", exc = getCurrentExceptionMsg()
+except LPStreamIncompleteError as exc:
+trace "unable to read varint", exc = exc.msg

 proc readMsg*(conn: Connection): Future[Option[Msg]] {.async, gcsafe.} =
 let headerVarint = await conn.readMplexVarint()
 if headerVarint.isNone:
 return

-trace "readMsg: read header varint ", varint = headerVarint
+trace "read header varint", varint = headerVarint

 let dataLenVarint = await conn.readMplexVarint()
 var data: seq[byte]
 if dataLenVarint.isSome and dataLenVarint.get() > 0.uint:
-trace "readMsg: read size varint ", varint = dataLenVarint
 data = await conn.read(dataLenVarint.get().int)
+trace "read size varint", varint = dataLenVarint

 let header = headerVarint.get()
 result = some((header shr 3, MessageType(header and 0x7), data))

@@ -64,11 +66,13 @@ proc writeMsg*(conn: Connection,
 data: seq[byte] = @[]) {.async, gcsafe.} =
 ## write lenght prefixed
 var buf = initVBuffer()
-let header = (id shl 3 or ord(msgType).uint)
-buf.writeVarint(id shl 3 or ord(msgType).uint)
-buf.writeVarint(data.len().uint) # size should be always sent
+buf.writePBVarint(id shl 3 or ord(msgType).uint)
+buf.writePBVarint(data.len().uint) # size should be always sent
 buf.finish()
-await conn.write(buf.buffer & data)
+try:
+await conn.write(buf.buffer & data)
+except LPStreamIncompleteError as exc:
+trace "unable to send message", exc = exc.msg

 proc writeMsg*(conn: Connection,
 id: uint,
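A worked example of the header packing used by readMsg/writeMsg above: the channel id is shifted left three bits and or-ed with the message type, and unpacked with shr 3 / and 0x7. The concrete value 2 for an outbound data message is an assumption for illustration only; the MessageType enum values are not shown in this diff.

let id = 5'u
let msgType = 2'u                    # assumed code of an outbound data message
let header = id shl 3 or msgType     # 42
echo header shr 3                    # 5 -> channel id, as recovered in readMsg
echo header and 0x7                  # 2 -> message type, as recovered in readMsg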
@@ -7,7 +7,6 @@
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

-import strformat
 import chronos, chronicles
 import types,
 coder,

@@ -52,99 +51,110 @@ proc newChannel*(id: uint,
 result.asyncLock = newAsyncLock()

 let chan = result
-proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
+proc writeHandler(data: seq[byte]): Future[void] {.async.} =
 # writes should happen in sequence
 await chan.asyncLock.acquire()
-trace "writeHandler: sending data ", data = data.toHex(), id = chan.id
+trace "sending data ", data = data.toHex(),
+id = chan.id,
+initiator = chan.initiator

 await conn.writeMsg(chan.id, chan.msgCode, data) # write header
 chan.asyncLock.release()

 result.initBufferStream(writeHandler, size)

-proc closeMessage(s: LPChannel) {.async, gcsafe.} =
+proc closeMessage(s: LPChannel) {.async.} =
 await s.conn.writeMsg(s.id, s.closeCode) # write header

-proc closed*(s: LPChannel): bool =
-s.closedLocal and s.closedLocal
-
 proc closedByRemote*(s: LPChannel) {.async.} =
 s.closedRemote = true

 proc cleanUp*(s: LPChannel): Future[void] =
+# method which calls the underlying buffer's `close`
+# method used instead of `close` since it's overloaded to
+# simulate half-closed streams
 result = procCall close(BufferStream(s))

+proc open*(s: LPChannel): Future[void] =
+s.conn.writeMsg(s.id, MessageType.New, s.name)
+
 method close*(s: LPChannel) {.async, gcsafe.} =
 s.closedLocal = true
 await s.closeMessage()

-proc resetMessage(s: LPChannel) {.async, gcsafe.} =
+proc resetMessage(s: LPChannel) {.async.} =
 await s.conn.writeMsg(s.id, s.resetCode)

-proc resetByRemote*(s: LPChannel) {.async, gcsafe.} =
+proc resetByRemote*(s: LPChannel) {.async.} =
 await allFutures(s.close(), s.closedByRemote())
 s.isReset = true

 proc reset*(s: LPChannel) {.async.} =
 await allFutures(s.resetMessage(), s.resetByRemote())

-proc isReadEof(s: LPChannel): bool =
-bool((s.closedRemote or s.closedLocal) and s.len() < 1)
+method closed*(s: LPChannel): bool =
+result = s.closedRemote and s.len == 0

-proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] {.gcsafe.} =
-if s.closedRemote:
+proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] =
+if s.closedRemote or s.isReset:
 raise newLPStreamClosedError()
+trace "pushing data to channel", data = data.toHex(),
+id = s.id,
+initiator = s.initiator

 result = procCall pushTo(BufferStream(s), data)

-method read*(s: LPChannel, n = -1): Future[seq[byte]] {.gcsafe.} =
-if s.isReadEof():
+method read*(s: LPChannel, n = -1): Future[seq[byte]] =
+if s.closed or s.isReset:
 raise newLPStreamClosedError()

 result = procCall read(BufferStream(s), n)

 method readExactly*(s: LPChannel,
 pbytes: pointer,
 nbytes: int):
-Future[void] {.gcsafe.} =
-if s.isReadEof():
+Future[void] =
+if s.closed or s.isReset:
 raise newLPStreamClosedError()
 result = procCall readExactly(BufferStream(s), pbytes, nbytes)

 method readLine*(s: LPChannel,
 limit = 0,
 sep = "\r\n"):
-Future[string] {.gcsafe.} =
-if s.isReadEof():
+Future[string] =
+if s.closed or s.isReset:
 raise newLPStreamClosedError()
 result = procCall readLine(BufferStream(s), limit, sep)

 method readOnce*(s: LPChannel,
 pbytes: pointer,
 nbytes: int):
-Future[int] {.gcsafe.} =
-if s.isReadEof():
+Future[int] =
+if s.closed or s.isReset:
 raise newLPStreamClosedError()
 result = procCall readOnce(BufferStream(s), pbytes, nbytes)

 method readUntil*(s: LPChannel,
 pbytes: pointer, nbytes: int,
 sep: seq[byte]):
-Future[int] {.gcsafe.} =
-if s.isReadEof():
+Future[int] =
+if s.closed or s.isReset:
 raise newLPStreamClosedError()
 result = procCall readOnce(BufferStream(s), pbytes, nbytes)

 method write*(s: LPChannel,
 pbytes: pointer,
-nbytes: int): Future[void] {.gcsafe.} =
-if s.closedLocal:
+nbytes: int): Future[void] =
+if s.closedLocal or s.isReset:
 raise newLPStreamClosedError()
 result = procCall write(BufferStream(s), pbytes, nbytes)

-method write*(s: LPChannel, msg: string, msglen = -1) {.async, gcsafe.} =
-if s.closedLocal:
+method write*(s: LPChannel, msg: string, msglen = -1) {.async.} =
+if s.closedLocal or s.isReset:
 raise newLPStreamClosedError()
 result = procCall write(BufferStream(s), msg, msglen)

-method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
-if s.closedLocal:
+method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} =
+if s.closedLocal or s.isReset:
 raise newLPStreamClosedError()
 result = procCall write(BufferStream(s), msg, msglen)
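The guards above implement half-closed channels: a locally closed or reset channel rejects writes, while reads keep draining buffered data until the remote has closed and the buffer is empty. A self-contained sketch of those rules with a simplified stand-in type (not the LPChannel object itself):

type Chan = object
  closedLocal, closedRemote, isReset: bool
  buffered: int

proc canWrite(c: Chan): bool = not (c.closedLocal or c.isReset)
proc canRead(c: Chan): bool = not ((c.closedRemote and c.buffered == 0) or c.isReset)

echo canWrite(Chan(closedLocal: true))               # false: local half is closed
echo canRead(Chan(closedRemote: true, buffered: 3))  # true: remaining data can still be read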
@@ -11,16 +11,14 @@
 ## Timeouts and message limits are still missing
 ## they need to be added ASAP

-import tables, sequtils, options, strformat
+import tables, sequtils, options
 import chronos, chronicles
-import coder, types, lpchannel,
-../muxer,
-../../varint,
+import ../muxer,
 ../../connection,
-../../vbuffer,
-../../protocols/protocol,
-../../stream/bufferstream,
-../../stream/lpstream
+../../stream/lpstream,
+coder,
+types,
+lpchannel

 logScope:
 topic = "Mplex"

@@ -34,9 +32,11 @@ type

 proc getChannelList(m: Mplex, initiator: bool): var Table[uint, LPChannel] =
 if initiator:
-result = m.remote
-else:
+trace "picking local channels", initiator = initiator
 result = m.local
+else:
+trace "picking remote channels", initiator = initiator
+result = m.remote

 proc newStreamInternal*(m: Mplex,
 initiator: bool = true,

@@ -45,17 +45,28 @@ proc newStreamInternal*(m: Mplex,
 Future[LPChannel] {.async, gcsafe.} =
 ## create new channel/stream
 let id = if initiator: m.currentId.inc(); m.currentId else: chanId
+trace "creating new channel", channelId = id, initiator = initiator
 result = newChannel(id, m.connection, initiator, name)
 m.getChannelList(initiator)[id] = result

+proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async, inline.} =
+## call the channel's `close` to signal the
+## remote that the channel is closing
+if not isNil(chann) and not chann.closed:
+await chann.close()
+await chann.cleanUp()
+m.getChannelList(initiator).del(chann.id)
+trace "cleaned up channel", id = chann.id
+
 method handle*(m: Mplex) {.async, gcsafe.} =
 trace "starting mplex main loop"
 try:
 while not m.connection.closed:
+trace "waiting for data"
 let msg = await m.connection.readMsg()
 if msg.isNone:
 # TODO: allow poll with timeout to avoid using `sleepAsync`
-await sleepAsync(10.millis)
+await sleepAsync(1.millis)
 continue

 let (id, msgType, data) = msg.get()

@@ -63,8 +74,11 @@ method handle*(m: Mplex) {.async, gcsafe.} =
 var channel: LPChannel
 if MessageType(msgType) != MessageType.New:
 let channels = m.getChannelList(initiator)
-if not channels.contains(id):
-trace "handle: Channel with id and msg type ", id = id, msg = msgType
+if id notin channels:
+trace "Channel not found, skipping", id = id,
+initiator = initiator,
+msg = msgType
+await sleepAsync(1.millis)
 continue
 channel = channels[id]

@@ -72,36 +86,44 @@ method handle*(m: Mplex) {.async, gcsafe.} =
 of MessageType.New:
 let name = cast[string](data)
 channel = await m.newStreamInternal(false, id, name)
-trace "handle: created channel ", id = id, name = name
+trace "created channel", id = id, name = name, inititator = true
 if not isNil(m.streamHandler):
 let stream = newConnection(channel)
 stream.peerInfo = m.connection.peerInfo
-let handlerFut = m.streamHandler(stream)

-# channel cleanup routine
-proc cleanUpChan(udata: pointer) {.gcsafe.} =
-if handlerFut.finished:
-channel.close().addCallback(
-proc(udata: pointer) =
-channel.cleanUp()
-.addCallback(proc(udata: pointer) =
-trace "handle: cleaned up channel ", id = id))
-handlerFut.addCallback(cleanUpChan)
+# cleanup channel once handler is finished
+# stream.closeEvent.wait().addCallback(
+# proc(udata: pointer) =
+# asyncCheck cleanupChann(m, channel, initiator))
+
+asyncCheck m.streamHandler(stream)
 continue
 of MessageType.MsgIn, MessageType.MsgOut:
-trace "handle: pushing data to channel ", id = id, msgType = msgType
+trace "pushing data to channel", id = id,
+initiator = initiator,
+msgType = msgType

 await channel.pushTo(data)
 of MessageType.CloseIn, MessageType.CloseOut:
-trace "handle: closing channel ", id = id, msgType = msgType
+trace "closing channel", id = id,
+initiator = initiator,
+msgType = msgType

 await channel.closedByRemote()
 m.getChannelList(initiator).del(id)
 of MessageType.ResetIn, MessageType.ResetOut:
-trace "handle: resetting channel ", id = id
+trace "resetting channel", id = id,
+initiator = initiator,
+msgType = msgType

 await channel.resetByRemote()
+m.getChannelList(initiator).del(id)
 break
-except:
-error "exception occurred", exception = getCurrentExceptionMsg()
+except CatchableError as exc:
+trace "exception occurred", exception = exc.msg
 finally:
+trace "stopping mplex main loop"
 await m.connection.close()

 proc newMplex*(conn: Connection,

@@ -112,13 +134,20 @@ proc newMplex*(conn: Connection,
 result.remote = initTable[uint, LPChannel]()
 result.local = initTable[uint, LPChannel]()

+let m = result
+conn.closeEvent.wait().addCallback(
+proc(udata: pointer) =
+asyncCheck m.close()
+)
+
 method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} =
 let channel = await m.newStreamInternal()
-await m.connection.writeMsg(channel.id, MessageType.New, name)
+# TODO: open the channel (this should be lazy)
+await channel.open()
 result = newConnection(channel)
 result.peerInfo = m.connection.peerInfo

 method close*(m: Mplex) {.async, gcsafe.} =
-await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.close())),
-allFutures(toSeq(m.local.values).mapIt(it.close()))])
-m.connection.reset()
+trace "closing mplex muxer"
+await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())),
+allFutures(toSeq(m.local.values).mapIt(it.reset()))])
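getChannelList above keeps separate tables for locally and remotely initiated channels because mplex ids are only unique per direction; both peers may legitimately use the same numeric id. A small illustration with plain tables (not the Mplex type itself):

import tables

var local = initTable[uint, string]()
var remote = initTable[uint, string]()
local[1'u] = "channel we initiated"
remote[1'u] = "channel the peer initiated"
echo local[1'u], " / ", remote[1'u]   # same id, no collision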
@@ -8,7 +8,6 @@
 ## those terms.

 import chronos
-import ../../connection

 const MaxMsgSize* = 1 shl 20 # 1mb
 const MaxChannels* = 1000
@@ -10,7 +10,27 @@
 import options
 import peer, multiaddress

-type PeerInfo* = object of RootObj
-peerId*: Option[PeerID]
-addrs*: seq[MultiAddress]
-protocols*: seq[string]
+type
+PeerInfo* = object of RootObj
+peerId*: Option[PeerID]
+addrs*: seq[MultiAddress]
+protocols*: seq[string]
+
+proc id*(p: PeerInfo): string =
+if p.peerId.isSome:
+result = p.peerId.get().pretty
+
+proc `$`*(p: PeerInfo): string =
+if p.peerId.isSome:
+result.add("PeerID: ")
+result.add(p.id & "\n")
+
+if p.addrs.len > 0:
+result.add("Peer Addrs: ")
+for a in p.addrs:
+result.add($a & "\n")
+
+if p.protocols.len > 0:
+result.add("Protocols: ")
+for proto in p.protocols:
+result.add(proto & "\n")
@@ -7,7 +7,7 @@
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

-import options, strformat
+import options
 import chronos, chronicles
 import ../protobuf/minprotobuf,
 ../peerinfo,

@@ -115,14 +115,14 @@ method init*(p: Identify) =
 trace "handling identify request"
 var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
 await conn.writeLp(pb.buffer)
+# await conn.close() #TODO: investigate why this breaks

 p.handler = handle
 p.codec = IdentifyCodec

 proc identify*(p: Identify,
 conn: Connection,
-remotePeerInfo: PeerInfo):
-Future[IdentifyInfo] {.async.} =
+remotePeerInfo: PeerInfo): Future[IdentifyInfo] {.async, gcsafe.} =
 var message = await conn.readLp()
 if len(message) == 0:
 trace "identify: Invalid or empty message received!"

@@ -139,7 +139,7 @@ proc identify*(p: Identify,
 if peer != remotePeerInfo.peerId.get():
 trace "Peer ids don't match",
 remote = peer.pretty(),
-local = remotePeerInfo.peerId.get().pretty()
+local = remotePeerInfo.id

 raise newException(IdentityNoMatchError,
 "Peer ids don't match")

@@ -149,5 +149,4 @@ proc identify*(p: Identify,
 proc push*(p: Identify, conn: Connection) {.async.} =
 await conn.write(IdentifyPushCodec)
 var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs())
-let length = pb.getLen()
 await conn.writeLp(pb.buffer)
@@ -8,9 +8,7 @@
 ## those terms.

 import chronos
-import ../connection,
-../peerinfo,
-../multiaddress
+import ../connection

 type
 LPProtoHandler* = proc (conn: Connection,
@@ -14,6 +14,7 @@ import rpcmsg,
 ../../peer,
 ../../peerinfo,
 ../../connection,
+../../stream/lpstream,
 ../../crypto/crypto,
 ../../protobuf/minprotobuf

@@ -45,7 +46,7 @@ proc handle*(p: PubSubPeer) {.async, gcsafe.} =
 trace "Decoded msg from peer", peer = p.id, msg = msg
 await p.handler(p, @[msg])
 except:
-error "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
+trace "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg()
 finally:
 trace "closing connection to pubsub peer", peer = p.id
 await p.conn.close()
@@ -8,8 +8,7 @@
 ## those terms.

 import chronos
-import secure,
-../../connection
+import secure, ../../connection

 const PlainTextCodec* = "/plaintext/1.0.0"
@@ -6,10 +6,12 @@
 ## at your option.
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.
+import options
 import chronos, chronicles
 import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode]
 import secure,
 ../../connection,
+../../stream/lpstream,
 ../../crypto/crypto,
 ../../crypto/ecnist,
 ../../protobuf/minprotobuf,

@@ -60,7 +62,6 @@ type
 ctxsha1: HMAC[sha1]

 SecureConnection* = ref object of Connection
-conn*: Connection
 writerMac: SecureMac
 readerMac: SecureMac
 writerCoder: SecureCipher

@@ -176,13 +177,13 @@ proc readMessage*(sconn: SecureConnection): Future[seq[byte]] {.async.} =
 ## Read message from channel secure connection ``sconn``.
 try:
 var buf = newSeq[byte](4)
-await sconn.conn.readExactly(addr buf[0], 4)
+await sconn.readExactly(addr buf[0], 4)
 let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
 (int(buf[2]) shl 8) or (int(buf[3]))
 trace "Recieved message header", header = toHex(buf), length = length
 if length <= SecioMaxMessageSize:
 buf.setLen(length)
-await sconn.conn.readExactly(addr buf[0], length)
+await sconn.readExactly(addr buf[0], length)
 trace "Received message body", length = length,
 buffer = toHex(buf)
 if sconn.macCheckAndDecode(buf):

@@ -213,21 +214,27 @@ proc writeMessage*(sconn: SecureConnection, message: seq[byte]) {.async.} =
 msg[3] = byte(length and 0xFF)
 trace "Writing message", message = toHex(msg)
 try:
-await sconn.conn.write(msg)
+await sconn.write(msg)
 except AsyncStreamWriteError:
 trace "Could not write to connection"

-proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
+proc newSecureConnection*(conn: Connection,
+hash: string,
+cipher: string,
 secrets: Secret,
-order: int): SecureConnection =
+order: int,
+peerId: PeerID): SecureConnection =
 ## Create new secure connection, using specified hash algorithm ``hash``,
 ## cipher algorithm ``cipher``, stretched keys ``secrets`` and order
 ## ``order``.
 new result

+result.stream = conn
+result.closeEvent = newAsyncEvent()
+
 let i0 = if order < 0: 1 else: 0
 let i1 = if order < 0: 0 else: 1

-result.conn = conn
 trace "Writer credentials", mackey = toHex(secrets.macOpenArray(i0)),
 enckey = toHex(secrets.keyOpenArray(i0)),
 iv = toHex(secrets.ivOpenArray(i0))

@@ -241,6 +248,8 @@ proc newSecureConnection*(conn: Connection, hash: string, cipher: string,
 result.readerCoder.init(cipher, secrets.keyOpenArray(i1),
 secrets.ivOpenArray(i1))

+result.peerInfo.peerId = some(peerId)
+
 proc transactMessage(conn: Connection,
 msg: seq[byte]): Future[seq[byte]] {.async.} =
 var buf = newSeq[byte](4)

@@ -281,7 +290,6 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
 remoteHashes: string
 remotePeerId: PeerID
 localPeerId: PeerID
-ekey: PrivateKey
 localBytesPubkey = s.localPublicKey.getBytes()

 if randomBytes(localNonce) != SecioNonceSize:

@@ -388,7 +396,8 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}

 # Perform Nonce exchange over encrypted channel.

-result = newSecureConnection(conn, hash, cipher, keys, order)
+result = newSecureConnection(conn, hash, cipher, keys, order, remotePeerId)

 await result.writeMessage(remoteNonce)
 var res = await result.readMessage()

@@ -400,17 +409,21 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.}
 trace "Secure handshake succeeded"

 proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} =
-while not sconn.conn.closed:
-try:
+try:
+while not sconn.closed:
 let msg = await sconn.readMessage()
-await stream.pushTo(msg)
-except CatchableError as exc:
-trace "exception in secio", exc = exc.msg
-return
-finally:
-trace "ending secio readLoop"
+if msg.len > 0:
+await stream.pushTo(msg)
+
+# tight loop, give a chance for other
+# stuff to run as well
+await sleepAsync(1.millis)
+except CatchableError as exc:
+trace "exception occured", exc = exc.msg
+finally:
+trace "ending secio readLoop", isclosed = sconn.closed()

-proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
+proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} =
 var sconn = await s.handshake(conn)
 proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
 trace "sending encrypted bytes", bytes = data.toHex()

@@ -419,7 +432,13 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} =
 var stream = newBufferStream(writeHandler)
 asyncCheck readLoop(sconn, stream)
 var secured = newConnection(stream)
-secured.peerInfo = sconn.conn.peerInfo
+secured.closeEvent.wait()
+.addCallback(proc(udata: pointer) =
+trace "wrapped connection closed, closing upstream"
+if not sconn.closed:
+asyncCheck sconn.close()
+)
+secured.peerInfo.peerId = sconn.peerInfo.peerId
 result = secured

 method init(s: Secio) {.gcsafe.} =
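readMessage above parses a 4-byte big-endian length prefix before reading the encrypted body, and writeMessage produces the matching header. A standalone worked example of that length arithmetic:

let buf = [byte 0x00, 0x00, 0x01, 0x2A]
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
             (int(buf[2]) shl 8) or (int(buf[3]))
echo length   # 298 bytes of ciphertext plus MAC follow the header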
@@ -8,7 +8,7 @@
 ## those terms.

 ## This module implements an asynchronous buffer stream
 ## which emulates physical async IO.
 ##
 ## The stream is based on the standard library's `Deque`,
 ## which is itself based on a ring buffer.

@@ -25,12 +25,12 @@
 ## ordered and asynchronous. Reads are queued up in order
 ## and are suspended when not enough data available. This
 ## allows preserving backpressure while maintaining full
 ## asynchrony. Both writting to the internal buffer with
 ## ``pushTo`` as well as reading with ``read*` methods,
 ## will suspend until either the amount of elements in the
 ## buffer goes below ``maxSize`` or more data becomes available.

-import deques, tables, sequtils, math
+import deques, math
 import chronos
 import ../stream/lpstream

@@ -38,33 +38,49 @@ const DefaultBufferSize* = 1024

 type
 # TODO: figure out how to make this generic to avoid casts
-WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
+WriteHandler* = proc (data: seq[byte]): Future[void]

 BufferStream* = ref object of LPStream
 maxSize*: int # buffer's max size in bytes
-readBuf: Deque[byte] # a deque is based on a ring buffer
+readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here
 readReqs: Deque[Future[void]] # use dequeue to fire reads in order
 dataReadEvent: AsyncEvent
 writeHandler*: WriteHandler
+lock: AsyncLock
+isPiped: bool

+AlreadyPipedError* = object of CatchableError
+NotWritableError* = object of CatchableError
+
+proc newAlreadyPipedError*(): ref Exception {.inline.} =
+result = newException(AlreadyPipedError, "stream already piped")
+
+proc newNotWritableError*(): ref Exception {.inline.} =
+result = newException(NotWritableError, "stream is not writable")
+
 proc requestReadBytes(s: BufferStream): Future[void] =
 ## create a future that will complete when more
 ## data becomes available in the read buffer
 result = newFuture[void]()
 s.readReqs.addLast(result)

-proc initBufferStream*(s: BufferStream, handler: WriteHandler, size: int = DefaultBufferSize) =
+proc initBufferStream*(s: BufferStream,
+handler: WriteHandler = nil,
+size: int = DefaultBufferSize) =
 s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size)
 s.readBuf = initDeque[byte](s.maxSize)
 s.readReqs = initDeque[Future[void]]()
 s.dataReadEvent = newAsyncEvent()
+s.lock = newAsyncLock()
 s.writeHandler = handler
+s.closeEvent = newAsyncEvent()

-proc newBufferStream*(handler: WriteHandler, size: int = DefaultBufferSize): BufferStream =
+proc newBufferStream*(handler: WriteHandler = nil,
+size: int = DefaultBufferSize): BufferStream =
 new result
 result.initBufferStream(handler, size)

 proc popFirst*(s: BufferStream): byte =
 result = s.readBuf.popFirst()
 s.dataReadEvent.fire()

@@ -78,15 +94,24 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) =

 proc len*(s: BufferStream): int = s.readBuf.len

-proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
+proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
 ## Write bytes to internal read buffer, use this to fill up the
 ## buffer with data.
 ##
 ## This method is async and will wait until all data has been
 ## written to the internal buffer; this is done so that backpressure
 ## is preserved.
+##

+await s.lock.acquire()
 var index = 0
 while true:

+# give readers a chance free up the buffer
+# it it's full.
+if s.readBuf.len >= s.maxSize:
+await sleepAsync(10.millis)
+
 while index < data.len and s.readBuf.len < s.maxSize:
 s.readBuf.addLast(data[index])
 inc(index)

@@ -94,18 +119,20 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} =
 # resolve the next queued read request
 if s.readReqs.len > 0:
 s.readReqs.popFirst().complete()

 if index >= data.len:
 break

 # if we couldn't transfer all the data to the
 # internal buf wait on a read event
 await s.dataReadEvent.wait()
+s.lock.release()

-method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
+method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async.} =
 ## Read all bytes (n <= 0) or exactly `n` bytes from buffer
 ##
 ## This procedure allocates buffer seq[byte] and return it as result.
+##
 var size = if n > 0: n else: s.readBuf.len()
 var index = 0
 while index < size:

@@ -116,25 +143,26 @@ method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
 if index < size:
 await s.requestReadBytes()

 method readExactly*(s: BufferStream,
 pbytes: pointer,
 nbytes: int):
-Future[void] {.async, gcsafe.} =
+Future[void] {.async.} =
 ## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store
 ## it to ``pbytes``.
 ##
 ## If EOF is received and ``nbytes`` is not yet read, the procedure
 ## will raise ``LPStreamIncompleteError``.
-let buff = await s.read(nbytes)
+##
+var buff = await s.read(nbytes)
 if nbytes > buff.len():
 raise newLPStreamIncompleteError()

-copyMem(pbytes, unsafeAddr buff[0], nbytes)
+copyMem(pbytes, addr buff[0], nbytes)

 method readLine*(s: BufferStream,
 limit = 0,
 sep = "\r\n"):
-Future[string] {.async, gcsafe.} =
+Future[string] {.async.} =
 ## Read one line from read-only stream ``rstream``, where ``"line"`` is a
 ## sequence of bytes ending with ``sep`` (default is ``"\r\n"``).
 ##

@@ -146,6 +174,7 @@ method readLine*(s: BufferStream,
 ##
 ## If ``limit`` more then 0, then result string will be limited to ``limit``
 ## bytes.
+##
 result = ""
 var lim = if limit <= 0: -1 else: limit
 var state = 0

@@ -170,14 +199,15 @@ method readLine*(s: BufferStream,
 method readOnce*(s: BufferStream,
 pbytes: pointer,
 nbytes: int):
-Future[int] {.async, gcsafe.} =
+Future[int] {.async.} =
 ## Perform one read operation on read-only stream ``rstream``.
 ##
 ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
 ## internal buffer, otherwise it will wait until some bytes will be received.
+##
 if s.readBuf.len == 0:
 await s.requestReadBytes()

 var len = if nbytes > s.readBuf.len: s.readBuf.len else: nbytes
 await s.readExactly(pbytes, len)
 result = len

@@ -186,7 +216,7 @@ method readUntil*(s: BufferStream,
 pbytes: pointer,
 nbytes: int,
 sep: seq[byte]):
-Future[int] {.async, gcsafe.} =
+Future[int] {.async.} =
 ## Read data from the read-only stream ``rstream`` until separator ``sep`` is
 ## found.
 ##

@@ -200,6 +230,7 @@ method readUntil*(s: BufferStream,
 ## will raise ``LPStreamLimitError``.
 ##
 ## Procedure returns actual number of bytes read.
+##
 var
 dest = cast[ptr UncheckedArray[byte]](pbytes)
 state = 0

@@ -231,22 +262,22 @@ method readUntil*(s: BufferStream,
 else:
 s.shrink(datalen)

 method write*(s: BufferStream,
 pbytes: pointer,
-nbytes: int): Future[void]
-{.gcsafe.} =
+nbytes: int): Future[void] =
 ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream
 ## ``rstream``.
 ##
 ## Return number of bytes actually consumed (discarded).
+##
 var buf: seq[byte] = newSeq[byte](nbytes)
 copyMem(addr buf[0], pbytes, nbytes)
-result = s.writeHandler(buf)
+if not isNil(s.writeHandler):
+result = s.writeHandler(buf)

 method write*(s: BufferStream,
 msg: string,
-msglen = -1): Future[void]
-{.gcsafe.} =
+msglen = -1): Future[void] =
 ## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``.
 ##
 ## String ``sbytes`` must not be zero-length.

@@ -254,14 +285,15 @@ method write*(s: BufferStream,
 ## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream.
 ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
 ## stream.
+##
 var buf = ""
 shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
 result = s.writeHandler(cast[seq[byte]](buf))
|
if not isNil(s.writeHandler):
|
||||||
|
result = s.writeHandler(cast[seq[byte]](buf))
|
||||||
|
|
||||||
method write*(s: BufferStream,
|
method write*(s: BufferStream,
|
||||||
msg: seq[byte],
|
msg: seq[byte],
|
||||||
msglen = -1): Future[void]
|
msglen = -1): Future[void] =
|
||||||
{.gcsafe.} =
|
|
||||||
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
|
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
|
||||||
## stream ``wstream``.
|
## stream ``wstream``.
|
||||||
##
|
##
|
||||||
|
@ -270,13 +302,56 @@ method write*(s: BufferStream,
|
||||||
## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream.
|
## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream.
|
||||||
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
|
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
|
||||||
## stream.
|
## stream.
|
||||||
|
##
|
||||||
var buf: seq[byte]
|
var buf: seq[byte]
|
||||||
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
|
shallowCopy(buf, if msglen > 0: msg[0..<msglen] else: msg)
|
||||||
result = s.writeHandler(buf)
|
if not isNil(s.writeHandler):
|
||||||
|
result = s.writeHandler(buf)
|
||||||
|
|
||||||
method close*(s: BufferStream) {.async, gcsafe.} =
|
proc pipe*(s: BufferStream,
|
||||||
|
target: BufferStream): BufferStream =
|
||||||
|
## pipe the write end of this stream to
|
||||||
|
## be the source of the target stream
|
||||||
|
##
|
||||||
|
## Note that this only works with the LPStream
|
||||||
|
## interface methods `read*` and `write` are
|
||||||
|
## piped.
|
||||||
|
##
|
||||||
|
if s.isPiped:
|
||||||
|
raise newAlreadyPipedError()
|
||||||
|
|
||||||
|
s.isPiped = true
|
||||||
|
let oldHandler = target.writeHandler
|
||||||
|
proc handler(data: seq[byte]) {.async, closure.} =
|
||||||
|
if not isNil(oldHandler):
|
||||||
|
await oldHandler(data)
|
||||||
|
|
||||||
|
# if we're piping to self,
|
||||||
|
# then add the data to the
|
||||||
|
# buffer directly and fire
|
||||||
|
# the read event
|
||||||
|
if s == target:
|
||||||
|
for b in data:
|
||||||
|
s.readBuf.addLast(b)
|
||||||
|
|
||||||
|
# notify main loop of available
|
||||||
|
# data
|
||||||
|
s.dataReadEvent.fire()
|
||||||
|
else:
|
||||||
|
await target.pushTo(data)
|
||||||
|
|
||||||
|
s.writeHandler = handler
|
||||||
|
result = target
|
||||||
|
|
||||||
|
proc `|`*(s: BufferStream, target: BufferStream): BufferStream =
|
||||||
|
## pipe operator to make piping less verbose
|
||||||
|
pipe(s, target)
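For orientation, piping two buffer streams with the new helpers looks roughly like this. This is only a sketch mirroring the tests added further down; it assumes the argument-less `newBufferStream()` constructor used there:

proc pipeSketch() {.async.} =
  var a = newBufferStream()
  var b = a | newBufferStream()              # same as a.pipe(newBufferStream())
  await a.write(cast[seq[byte]]("Hello!"))   # writes to `a` are forwarded by its writeHandler ...
  let msg = await b.read(6)                  # ... and can be read back from `b`
  assert msg == cast[seq[byte]]("Hello!")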

method close*(s: BufferStream) {.async.} =
  ## close the stream and clear the buffer
  for r in s.readReqs:
    r.cancel()

  s.dataReadEvent.fire()
  s.readBuf.clear()
  s.closed = true
  s.closeEvent.fire()
  s.isClosed = true

@ -26,40 +26,63 @@ proc newChronosStream*(server: StreamServer,
  result.client = client
  result.reader = newAsyncStreamReader(client)
  result.writer = newAsyncStreamWriter(client)
  result.closed = false
  result.closeEvent = newAsyncEvent()

method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async.} =
  if s.reader.atEof:
    raise newLPStreamClosedError()

method read*(s: ChronosStream, n = -1): Future[seq[byte]] {.async, gcsafe.} =
  try:
    result = await s.reader.read(n)
  except AsyncStreamReadError as exc:
    raise newLPStreamReadError(exc.par)
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method readExactly*(s: ChronosStream,
                    pbytes: pointer,
                    nbytes: int): Future[void] {.async, gcsafe.} =
                    nbytes: int): Future[void] {.async.} =
  if s.reader.atEof:
    raise newLPStreamClosedError()

  try:
    await s.reader.readExactly(pbytes, nbytes)
  except AsyncStreamIncompleteError:
    raise newLPStreamIncompleteError()
  except AsyncStreamReadError as exc:
    raise newLPStreamReadError(exc.par)
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async.} =
  if s.reader.atEof:
    raise newLPStreamClosedError()

method readLine*(s: ChronosStream, limit = 0, sep = "\r\n"): Future[string] {.async, gcsafe.} =
  try:
    result = await s.reader.readLine(limit, sep)
  except AsyncStreamReadError as exc:
    raise newLPStreamReadError(exc.par)
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
  if s.reader.atEof:
    raise newLPStreamClosedError()

method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async, gcsafe.} =
  try:
    result = await s.reader.readOnce(pbytes, nbytes)
  except AsyncStreamReadError as exc:
    raise newLPStreamReadError(exc.par)
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method readUntil*(s: ChronosStream,
                  pbytes: pointer,
                  nbytes: int,
                  sep: seq[byte]): Future[int] {.async, gcsafe.} =
                  sep: seq[byte]): Future[int] {.async.} =
  if s.reader.atEof:
    raise newLPStreamClosedError()

  try:
    result = await s.reader.readUntil(pbytes, nbytes, sep)
  except AsyncStreamIncompleteError:
@ -68,36 +91,62 @@ method readUntil*(s: ChronosStream,
    raise newLPStreamLimitError()
  except LPStreamReadError as exc:
    raise newLPStreamReadError(exc.par)
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async.} =
  if s.writer.atEof:
    raise newLPStreamClosedError()

method write*(s: ChronosStream, pbytes: pointer, nbytes: int) {.async, gcsafe.} =
  try:
    await s.writer.write(pbytes, nbytes)
  except AsyncStreamWriteError as exc:
    raise newLPStreamWriteError(exc.par)
  except AsyncStreamIncompleteError:
    raise newLPStreamIncompleteError()
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method write*(s: ChronosStream, msg: string, msglen = -1) {.async.} =
  if s.writer.atEof:
    raise newLPStreamClosedError()

method write*(s: ChronosStream, msg: string, msglen = -1) {.async, gcsafe.} =
  try:
    await s.writer.write(msg, msglen)
  except AsyncStreamWriteError as exc:
    raise newLPStreamWriteError(exc.par)
  except AsyncStreamIncompleteError:
    raise newLPStreamIncompleteError()
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} =
  if s.writer.atEof:
    raise newLPStreamClosedError()

method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
  try:
    await s.writer.write(msg, msglen)
  except AsyncStreamWriteError as exc:
    raise newLPStreamWriteError(exc.par)
  except AsyncStreamIncompleteError:
    raise newLPStreamIncompleteError()
  except AsyncStreamIncorrectError as exc:
    raise newLPStreamIncorrectError(exc.msg)

method close*(s: ChronosStream) {.async, gcsafe.} =
method closed*(s: ChronosStream): bool {.inline.} =
  # TODO: we might only need to check for reader's EOF
  result = s.reader.atEof()

method close*(s: ChronosStream) {.async.} =
  if not s.closed:
    trace "shutting down server", address = $s.client.remoteAddress()
    trace "shutting chronos stream", address = $s.client.remoteAddress()
    await s.writer.finish()
    if not s.writer.closed():
      await s.writer.closeWait()
    await s.reader.closeWait()
    await s.client.closeWait()
    if not s.reader.closed():
    s.closed = true
      await s.reader.closeWait()

    if not s.client.closed():
      await s.client.closeWait()

    s.closeEvent.fire()

@ -11,7 +11,8 @@ import chronos

type
  LPStream* = ref object of RootObj
    closed*: bool
    isClosed*: bool
    closeEvent*: AsyncEvent

  LPStreamError* = object of CatchableError
  LPStreamIncompleteError* = object of LPStreamError
@ -47,40 +48,43 @@ proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} =
proc newLPStreamClosedError*(): ref Exception {.inline.} =
  result = newException(LPStreamClosedError, "Stream closed!")

method closed*(s: LPStream): bool {.base, inline.} =
  s.isClosed
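As a point of reference, a concrete LPStream implementation now only has to flip `isClosed` and fire `closeEvent` when it shuts down; the base `closed` accessor covers the rest. A minimal sketch follows (the `DummyStream` type is hypothetical and not part of this change):

type
  DummyStream = ref object of LPStream  # hypothetical subclass for illustration

proc newDummyStream(): DummyStream =
  new result
  result.closeEvent = newAsyncEvent()

method close(s: DummyStream) {.async.} =
  if not s.isClosed:
    s.isClosed = true     # makes the base `closed` accessor return true
    s.closeEvent.fire()   # wakes anything waiting on the stream's close event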

method read*(s: LPStream, n = -1): Future[seq[byte]]
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method readExactly*(s: LPStream, pbytes: pointer, nbytes: int): Future[void]
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string]
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method readOnce*(s: LPStream, pbytes: pointer, nbytes: int): Future[int]
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method readUntil*(s: LPStream,
                  pbytes: pointer, nbytes: int,
                  sep: seq[byte]): Future[int]
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method write*(s: LPStream, pbytes: pointer, nbytes: int)
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method write*(s: LPStream, msg: string, msglen = -1)
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method write*(s: LPStream, msg: seq[byte], msglen = -1)
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

method close*(s: LPStream)
  {.base, async, gcsafe.} =
  {.base, async.} =
  doAssert(false, "not implemented!")

@ -11,13 +11,11 @@ import tables, sequtils, options, strformat
import chronos, chronicles
import connection,
       transports/transport,
       stream/lpstream,
       multistream,
       protocols/protocol,
       protocols/secure/secure,
       protocols/secure/plaintext, # for plain text
       peerinfo,
       multiaddress,
       protocols/identify,
       protocols/pubsub/pubsub,
       muxers/muxer,
@ -26,6 +24,12 @@ import connection,
logScope:
  topic = "Switch"

#TODO: General note - use a finite state machine to manage the different
# steps of connection establishment and upgrade. This makes everything
# more robust and less prone to ordering attacks - i.e. muxing can come if
# and only if the channel has been secured (i.e. if a secure manager has been
# previously provided)
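Purely to illustrate that TODO, such a state machine could be shaped roughly like the sketch below; the type, states and transitions are hypothetical and nothing here exists in the codebase:

type
  UpgradeState = enum
    Connected,  # raw transport connection established
    Secured,    # a secure manager has wrapped the connection
    Muxed       # a muxer is running on top of the secured connection

proc advance(state: UpgradeState, event: string): UpgradeState =
  # muxing is only reachable once the connection has been secured
  result = state
  case state
  of Connected:
    if event == "secured": result = Secured
  of Secured:
    if event == "muxed": result = Muxed
  of Muxed:
    discard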

type
  NoPubSubException = object of CatchableError

@ -48,7 +52,6 @@ proc newNoPubSubException(): ref Exception {.inline.} =
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
  ## secure the incoming connection

  # plaintext for now, doesn't do anything
  let managers = toSeq(s.secureManagers.keys)
  if managers.len == 0:
    raise newException(CatchableError, "No secure managers registered!")
@ -62,20 +65,21 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
  ## identify the connection

  result = conn.peerInfo
  try:
    if (await s.ms.select(conn, s.identity.codec)):
      let info = await s.identity.identify(conn, conn.peerInfo)

      if info.pubKey.isSome:
        result.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all
        trace "identify: identified remote peer", peer = result.id

      if info.addrs.len > 0:
        result.addrs = info.addrs

      if info.protos.len > 0:
        result.protocols = info.protos

      trace "identify: identified remote peer ", peer = result.peerId.get().pretty
  except IdentityInvalidMsgError as exc:
    error "identify: invalid message", msg = exc.msg
  except IdentityNoMatchError as exc:
@ -100,22 +104,23 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
  muxer.streamHandler = s.streamHandler

  # new stream for identify
  let stream = await muxer.newStream()
  var stream = await muxer.newStream()
  let handlerFut = muxer.handle()

  # add muxer handler cleanup proc
  handlerFut.addCallback(
    proc(udata: pointer = nil) {.gcsafe.} =
      trace "mux: Muxer handler completed for peer ",
      trace "muxer handler completed for peer",
        peer = conn.peerInfo.peerId.get().pretty
        peer = conn.peerInfo.id
  )

  # do identify first, so that we have a
  # PeerInfo in case we didn't before
  conn.peerInfo = await s.identify(stream)
  await stream.close() # close idenity stream
  await stream.close() # close identify stream
  trace "connection's peerInfo", peerInfo = conn.peerInfo.peerId
  trace "connection's peerInfo", peerInfo = conn.peerInfo

  # store it in muxed connections if we have a peer for it
  # TODO: We should make sure that these are cleaned up properly
@ -123,43 +128,42 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
  # happen once secio is in place, but still something to keep
  # in mind
  if conn.peerInfo.peerId.isSome:
    trace "adding muxer for peer", peer = conn.peerInfo.peerId.get().pretty
    trace "adding muxer for peer", peer = conn.peerInfo.id
    s.muxed[conn.peerInfo.peerId.get().pretty] = muxer
    s.muxed[conn.peerInfo.id] = muxer

proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
  if conn.peerInfo.peerId.isSome:
    let id = conn.peerInfo.peerId.get().pretty
    let id = conn.peerInfo.id
    if s.muxed.contains(id):
    trace "cleaning up connection for peer", peerId = id
      await s.muxed[id].close
    if id in s.muxed:
      await s.muxed[id].close()
    if s.connections.contains(id):
      s.muxed.del(id)

    if id in s.connections:
      await s.connections[id].close()
      s.connections.del(id)

proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} =
  # if there is a muxer for the connection
  # use it instead to create a muxed stream
  if s.muxed.contains(peerInfo.peerId.get().pretty):
  if peerInfo.id in s.muxed:
    trace "connection is muxed, retriving muxer and setting up a stream"
    trace "connection is muxed, setting up a stream"
    let muxer = s.muxed[peerInfo.peerId.get().pretty]
    let muxer = s.muxed[peerInfo.id]
    let conn = await muxer.newStream()
    result = some(conn)

proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
  trace "handling connection", conn = conn
  result = conn
  ## perform upgrade flow
  if result.peerInfo.peerId.isSome:
    let id = result.peerInfo.peerId.get().pretty
    if s.connections.contains(id):
      # if we already have a connection for this peer,
      # close the incoming connection and return the
      # existing one
      await result.close()
      return s.connections[id]
    s.connections[id] = result

  result = await s.secure(conn) # secure the connection
  # don't mux/secure twice
  if conn.peerInfo.peerId.isSome and
    conn.peerInfo.id in s.muxed:
    return

  result = await s.secure(result) # secure the connection
  await s.mux(result) # mux it if possible
  s.connections[conn.peerInfo.id] = result

proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
  trace "upgrading incoming connection"
@ -192,42 +196,57 @@ proc dial*(s: Switch,
           peer: PeerInfo,
           proto: string = ""):
           Future[Connection] {.async.} =
  trace "dialing peer", peer = peer.peerId.get().pretty
  let id = peer.id
  trace "dialing peer", peer = id
  for t in s.transports: # for each transport
    for a in peer.addrs: # for each address
      if t.handles(a): # check if it can dial it
        result = await t.dial(a)
        if id notin s.connections:
        # make sure to assign the peer to the connection
          trace "dialing address", address = $a
        result.peerInfo = peer
          result = await t.dial(a)
          # make sure to assign the peer to the connection
          result.peerInfo = peer
          result = await s.upgradeOutgoing(result)
          result.closeEvent.wait().addCallback(
            proc(udata: pointer) =
              asyncCheck s.cleanupConn(result)
          )

        let stream = await s.getMuxedStream(peer)
        if proto.len > 0 and not result.closed:
        if stream.isSome:
          let stream = await s.getMuxedStream(peer)
          trace "connection is muxed, return muxed stream"
          if stream.isSome:
          result = stream.get()
            trace "connection is muxed, return muxed stream"
            result = stream.get()
          trace "attempting to select remote", proto = proto

        trace "dial: attempting to select remote ", proto = proto
          if not (await s.ms.select(result, proto)):
        if not (await s.ms.select(result, proto)):
            error "unable to select protocol: ", proto = proto
          error "dial: Unable to select protocol: ", proto = proto
            raise newException(CatchableError,
          raise newException(CatchableError,
              &"unable to select protocol: {proto}")
            &"Unable to select protocol: {proto}")

        break # don't dial more than one addr on the same transport
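For context, a dialing peer is expected to use the proc above roughly as follows; this is a sketch only, the protocol id is a made-up example and error handling is omitted:

proc dialSketch(s: Switch, remote: PeerInfo) {.async.} =
  # dial and negotiate an application protocol; the returned connection
  # has already been secured and, where possible, muxed
  let conn = await s.dial(remote, "/example/proto/1.0.0")
  await conn.write(cast[seq[byte]]("ping"))
  await conn.close()  # cleanup is driven by closeEvent, see cleanupConn above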

proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
  if isNil(proto.handler):
    raise newException(CatchableError,
      "Protocol has to define a handle method or proc")

  if proto.codec.len == 0:
    raise newException(CatchableError,
      "Protocol has to define a codec string")

  s.ms.addHandler(proto.codec, proto)

proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
  trace "starting switch"

  proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
    try:
      await s.upgradeIncoming(conn) # perform upgrade on incoming connection
    except CatchableError as exc:
      trace "exception occured", exc = exc.msg
    finally:
      await conn.close()
      await s.cleanupConn(conn)

  var startFuts: seq[Future[void]]
@ -237,10 +256,13 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
      var server = await t.listen(a, handle)
      s.peerInfo.addrs[i] = t.ma # update peer's address
      startFuts.add(server)

  result = startFuts # listen for incoming connections

proc stop*(s: Switch) {.async.} =
  await allFutures(toSeq(s.connections.values).mapIt(it.close()))
  trace "stopping switch"

  await allFutures(toSeq(s.connections.values).mapIt(s.cleanupConn(it)))
  await allFutures(s.transports.mapIt(it.close()))

proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
@ -253,14 +275,14 @@ proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] {
  ## subscribe to a pubsub topic
  if s.pubSub.isNone:
    raise newNoPubSubException()

  result = s.pubSub.get().subscribe(topic, handler)

proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] {.gcsafe.} =
  ## unsubscribe from topics
  if s.pubSub.isNone:
    raise newNoPubSubException()

  result = s.pubSub.get().unsubscribe(topics)

proc publish*(s: Switch, topic: string, data: seq[byte]): Future[void] {.gcsafe.} =

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.

import chronos, chronicles
import chronos, chronicles, sequtils
import transport,
       ../wire,
       ../connection,
@ -78,5 +78,5 @@ method dial*(t: TcpTransport,
  result = await t.connHandler(t.server, client, true)

method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
  ## TODO: implement logic to properly discriminat TCP multiaddrs
  if procCall Transport(t).handles(address):
  true
    result = address.protocols.filterIt( it == multiCodec("tcp") ).len > 0

@ -9,8 +9,7 @@

import sequtils
import chronos, chronicles
import ../peerinfo,
import ../connection,
       ../connection,
       ../multiaddress,
       ../multicodec

@ -62,9 +61,10 @@ method upgrade*(t: Transport) {.base, async, gcsafe.} =

method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} =
  ## check if transport supports the multiaddress
  # TODO: this should implement generic logic that would use the multicodec
  # declared in the multicodec field and set by each individual transport
  # by default we skip circuit addresses to avoid
  discard
  # having to repeat the check in every transport
  address.protocols.filterIt( it == multiCodec("p2p-circuit") ).len == 0

method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} =
  ## get the local address of the transport in case started with 0.0.0.0:0

@ -53,7 +53,17 @@ proc initVBuffer*(): VBuffer =
  ## Initialize empty VBuffer.
  result.buffer = newSeqOfCap[byte](128)

proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
proc writePBVarint*(vb: var VBuffer, value: PBSomeUVarint) =
  ## Write ``value`` as variable unsigned integer.
  var length = 0
  var v = value and cast[type(value)](0xFFFF_FFFF_FFFF_FFFF)
  vb.buffer.setLen(len(vb.buffer) + vsizeof(v))
  let res = PB.putUVarint(toOpenArray(vb.buffer, vb.offset, len(vb.buffer) - 1),
                          length, v)
  doAssert(res == VarintStatus.Success)
  vb.offset += length

proc writeLPVarint*(vb: var VBuffer, value: LPSomeUVarint) =
  ## Write ``value`` as variable unsigned integer.
  var length = 0
  # LibP2P varint supports only 63 bits.
@ -64,6 +74,9 @@ proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
  doAssert(res == VarintStatus.Success)
  vb.offset += length

proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) =
  writeLPVarint(vb, value)
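To make the split concrete: callers now pick the varint flavour explicitly, while `writeVarint` keeps its old behaviour by forwarding to the LibP2P form. A small usage sketch, assuming a 64-bit unsigned value is accepted by both overloads:

var vb = initVBuffer()
vb.writePBVarint(300'u64)  # protobuf-style varint, full 64-bit range
vb.writeLPVarint(300'u64)  # LibP2P-style varint, limited to 63 bits
vb.writeVarint(300'u64)    # equivalent to writeLPVarint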

proc writeSeq*[T: byte|char](vb: var VBuffer, value: openarray[T]) =
  ## Write array ``value`` to buffer ``vb``, value will be prefixed with
  ## varint length of the array.

@ -1,4 +1,4 @@
import unittest, deques, sequtils, strformat
import unittest, strformat
import chronos
import ../libp2p/stream/bufferstream

@ -220,7 +220,6 @@ suite "BufferStream":

  test "reads should happen in order":
    proc testWritePtr(): Future[bool] {.async.} =
      var count = 1
      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
      let buff = newBufferStream(writeHandler, 10)
      check buff.len == 0
@ -245,3 +244,199 @@ suite "BufferStream":

    check:
      waitFor(testWritePtr()) == true

  test "pipe two streams without the `pipe` or `|` helpers":
    proc pipeTest(): Future[bool] {.async.} =
      proc writeHandler1(data: seq[byte]) {.async, gcsafe.}
      proc writeHandler2(data: seq[byte]) {.async, gcsafe.}

      var buf1 = newBufferStream(writeHandler1)
      var buf2 = newBufferStream(writeHandler2)

      proc writeHandler1(data: seq[byte]) {.async, gcsafe.} =
        var msg = cast[string](data)
        check msg == "Hello!"
        await buf2.pushTo(data)

      proc writeHandler2(data: seq[byte]) {.async, gcsafe.} =
        var msg = cast[string](data)
        check msg == "Hello!"
        await buf1.pushTo(data)

      var res1: seq[byte] = newSeq[byte](7)
      var readFut1 = buf1.readExactly(addr res1[0], 7)

      var res2: seq[byte] = newSeq[byte](7)
      var readFut2 = buf2.readExactly(addr res2[0], 7)

      await buf1.pushTo(cast[seq[byte]]("Hello2!"))
      await buf2.pushTo(cast[seq[byte]]("Hello1!"))

      await allFutures(readFut1, readFut2)

      check:
        res1 == cast[seq[byte]]("Hello2!")
        res2 == cast[seq[byte]]("Hello1!")

      result = true

    check:
      waitFor(pipeTest()) == true

  test "pipe A -> B":
    proc pipeTest(): Future[bool] {.async.} =
      var buf1 = newBufferStream()
      var buf2 = buf1.pipe(newBufferStream())

      var res1: seq[byte] = newSeq[byte](7)
      var readFut = buf2.readExactly(addr res1[0], 7)
      await buf1.write(cast[seq[byte]]("Hello1!"))
      await readFut

      check:
        res1 == cast[seq[byte]]("Hello1!")

      result = true

    check:
      waitFor(pipeTest()) == true

  test "pipe A -> B and B -> A":
    proc pipeTest(): Future[bool] {.async.} =
      var buf1 = newBufferStream()
      var buf2 = newBufferStream()

      buf1 = buf1.pipe(buf2).pipe(buf1)

      var res1: seq[byte] = newSeq[byte](7)
      var readFut1 = buf1.readExactly(addr res1[0], 7)

      var res2: seq[byte] = newSeq[byte](7)
      var readFut2 = buf2.readExactly(addr res2[0], 7)

      await buf1.write(cast[seq[byte]]("Hello1!"))
      await buf2.write(cast[seq[byte]]("Hello2!"))
      await allFutures(readFut1, readFut2)

      check:
        res1 == cast[seq[byte]]("Hello2!")
        res2 == cast[seq[byte]]("Hello1!")

      result = true

    check:
      waitFor(pipeTest()) == true

  test "pipe A -> A (echo)":
    proc pipeTest(): Future[bool] {.async.} =
      var buf1 = newBufferStream()

      buf1 = buf1.pipe(buf1)

      proc reader(): Future[seq[byte]] = buf1.read(6)
      proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))

      var writerFut = writer()
      var readerFut = reader()

      await writerFut
      check:
        (await readerFut) == cast[seq[byte]]("Hello!")

      result = true

    check:
      waitFor(pipeTest()) == true

  test "pipe with `|` operator - A -> B":
    proc pipeTest(): Future[bool] {.async.} =
      var buf1 = newBufferStream()
      var buf2 = buf1 | newBufferStream()

      var res1: seq[byte] = newSeq[byte](7)
      var readFut = buf2.readExactly(addr res1[0], 7)
      await buf1.write(cast[seq[byte]]("Hello1!"))
      await readFut

      check:
        res1 == cast[seq[byte]]("Hello1!")

      result = true

    check:
      waitFor(pipeTest()) == true

  test "pipe with `|` operator - A -> B and B -> A":
    proc pipeTest(): Future[bool] {.async.} =
      var buf1 = newBufferStream()
      var buf2 = newBufferStream()

      buf1 = buf1 | buf2 | buf1

      var res1: seq[byte] = newSeq[byte](7)
      var readFut1 = buf1.readExactly(addr res1[0], 7)

      var res2: seq[byte] = newSeq[byte](7)
      var readFut2 = buf2.readExactly(addr res2[0], 7)

      await buf1.write(cast[seq[byte]]("Hello1!"))
      await buf2.write(cast[seq[byte]]("Hello2!"))
      await allFutures(readFut1, readFut2)

      check:
        res1 == cast[seq[byte]]("Hello2!")
        res2 == cast[seq[byte]]("Hello1!")

      result = true

    check:
      waitFor(pipeTest()) == true

  test "pipe with `|` operator - A -> A (echo)":
    proc pipeTest(): Future[bool] {.async.} =
      var buf1 = newBufferStream()

      buf1 = buf1 | buf1

      proc reader(): Future[seq[byte]] = buf1.read(6)
      proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!"))

      var writerFut = writer()
      var readerFut = reader()

      await writerFut
      check:
        (await readerFut) == cast[seq[byte]]("Hello!")

      result = true

    check:
      waitFor(pipeTest()) == true

  # TODO: Need to implement deadlock prevention when
  # piping to self
  test "pipe deadlock":
    proc pipeTest(): Future[bool] {.async.} =

      var buf1 = newBufferStream(size = 5)

      buf1 = buf1 | buf1

      var count = 30000
      proc reader() {.async.} =
        while count > 0:
          discard await buf1.read(7)

      proc writer() {.async.} =
        while count > 0:
          await buf1.write(cast[seq[byte]]("Hello2!"))
          count.dec

      var writerFut = writer()
      var readerFut = reader()

      await allFutures(readerFut, writerFut)
      result = true

    check:
      waitFor(pipeTest()) == true

@ -274,25 +274,15 @@
    expect LPStreamClosedError:
      waitFor(testClosedForWrite())

  test "half closed - channel should close for read":
  test "half closed - channel should close for read by remote":
    proc testClosedForRead(): Future[void] {.async.} =
      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
      let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
      await chann.closedByRemote()
      asyncDiscard chann.read()

    expect LPStreamClosedError:
      waitFor(testClosedForRead())

  test "half closed - channel should close for read after eof":
    proc testClosedForRead(): Future[void] {.async.} =
      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
      let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)

      await chann.pushTo(cast[seq[byte]]("Hello!"))
      await chann.close()
      await chann.closedByRemote()
      let msg = await chann.read()
      discard await chann.read() # this should work, since there is data in the buffer
      asyncDiscard chann.read()
      discard await chann.read() # this should throw

    expect LPStreamClosedError:
      waitFor(testClosedForRead())

@ -312,7 +302,7 @@ suite "Mplex":
      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
      let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true)
      await chann.reset()
      asyncDiscard chann.read()
      await chann.write(cast[seq[byte]]("Hello!"))

    expect LPStreamClosedError:
      waitFor(testResetWrite())

@ -1,4 +1,4 @@
import unittest, strutils, sequtils, sugar, strformat, options
import unittest, strutils, sequtils, strformat, options
import chronos
import ../libp2p/connection,
       ../libp2p/multistream,
@ -51,7 +51,8 @@ method write*(s: TestSelectStream, msg: seq[byte], msglen = -1)
method write*(s: TestSelectStream, msg: string, msglen = -1)
  {.async, gcsafe.} = discard

method close(s: TestSelectStream) {.async, gcsafe.} = s.closed = true
method close(s: TestSelectStream) {.async, gcsafe.} =
  s.isClosed = true

proc newTestSelectStream(): TestSelectStream =
  new result
@ -97,7 +98,8 @@ method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
method write*(s: TestLsStream, msg: string, msglen = -1)
  {.async, gcsafe.} = discard

method close(s: TestLsStream) {.async, gcsafe.} = s.closed = true
method close(s: TestLsStream) {.async, gcsafe.} =
  s.isClosed = true

proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} =
  new result
@ -143,7 +145,8 @@ method write*(s: TestNaStream, msg: string, msglen = -1) {.async, gcsafe.} =
  if s.step == 4:
    await s.na(msg)

method close(s: TestNaStream) {.async, gcsafe.} = s.closed = true
method close(s: TestNaStream) {.async, gcsafe.} =
  s.isClosed = true

proc newTestNaStream(na: NaHandler): TestNaStream =
  new result

@ -2,5 +2,11 @@ import unittest
import testvarint, testbase32, testbase58, testbase64
import testrsa, testecnist, tested25519, testsecp256k1, testcrypto
import testmultibase, testmultihash, testmultiaddress, testcid, testpeer
import testtransport, testmultistream, testbufferstream,
       testmplex, testidentify, testswitch, testpubsub
import testtransport,
       testmultistream,
       testbufferstream,
       testidentify,
       testswitch,
       testpubsub,
       testmplex

@ -1,5 +1,5 @@
import unittest, tables, options
import chronos, chronicles
import chronos
import ../libp2p/[switch,
                  multistream,
                  protocols/identify,
@ -36,7 +36,7 @@ method init(p: TestProto) {.gcsafe.} =

suite "Switch":
  test "e2e use switch":
    proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
    proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) {.gcsafe.}=
      let seckey = PrivateKey.random(RSA)
      var peerInfo: PeerInfo
      peerInfo.peerId = some(PeerID.init(seckey))
@ -50,7 +50,11 @@ suite "Switch":
      let transports = @[Transport(newTransport(TcpTransport))]
      let muxers = [(MplexCodec, mplexProvider)].toTable()
      let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable()
      let switch = newSwitch(peerInfo, transports, identify, muxers, secureManagers)
      let switch = newSwitch(peerInfo,
                             transports,
                             identify,
                             muxers,
                             secureManagers)
      result = (switch, peerInfo)

    proc testSwitch(): Future[bool] {.async, gcsafe.} =