diff --git a/libp2p/connection.nim b/libp2p/connection.nim index 8eb925bd1..af3280975 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import chronos, options, chronicles +import chronos, chronicles import peerinfo, multiaddress, stream/lpstream, @@ -26,15 +26,28 @@ type InvalidVarintException = object of LPStreamError proc newInvalidVarintException*(): ref InvalidVarintException = - result = newException(InvalidVarintException, "unable to prase varint") + newException(InvalidVarintException, "unable to prase varint") proc newConnection*(stream: LPStream): Connection = ## create a new Connection for the specified async reader/writer new result result.stream = stream + result.closeEvent = newAsyncEvent() + + # bind stream's close event to connection's close + # to ensure correct close propagation + let this = result + if not isNil(result.stream.closeEvent): + result.stream.closeEvent.wait(). + addCallback( + proc (udata: pointer) = + if not this.closed: + trace "closing this connection because wrapped stream closed" + asyncCheck this.close() + ) method read*(s: Connection, n = -1): Future[seq[byte]] {.gcsafe.} = - result = s.stream.read(n) + s.stream.read(n) method readExactly*(s: Connection, pbytes: pointer, @@ -44,13 +57,13 @@ method readExactly*(s: Connection, method readLine*(s: Connection, limit = 0, - sep = "\r\n"): + sep = "\r\n"): Future[string] {.gcsafe.} = s.stream.readLine(limit, sep) method readOnce*(s: Connection, pbytes: pointer, - nbytes: int): + nbytes: int): Future[int] {.gcsafe.} = s.stream.readOnce(pbytes, nbytes) @@ -61,15 +74,15 @@ method readUntil*(s: Connection, Future[int] {.gcsafe.} = s.stream.readUntil(pbytes, nbytes, sep) -method write*(s: Connection, - pbytes: pointer, - nbytes: int): +method write*(s: Connection, + pbytes: pointer, + nbytes: int): Future[void] {.gcsafe.} = s.stream.write(pbytes, nbytes) -method write*(s: Connection, - msg: string, - msglen = -1): +method write*(s: Connection, + msg: string, + msglen = -1): Future[void] {.gcsafe.} = s.stream.write(msg, msglen) @@ -79,9 +92,20 @@ method write*(s: Connection, Future[void] {.gcsafe.} = s.stream.write(msg, msglen) +method closed*(s: Connection): bool = + if isNil(s.stream): + return false + + result = s.stream.closed + method close*(s: Connection) {.async, gcsafe.} = - await s.stream.close() - s.closed = true + trace "closing connection" + if not s.closed: + if not isNil(s.stream) and not s.stream.closed: + await s.stream.close() + s.closeEvent.fire() + s.isClosed = true + trace "connection closed", closed = s.closed proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} = ## read lenght prefixed msg @@ -100,21 +124,23 @@ proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} = raise newInvalidVarintException() result.setLen(size) if size > 0.uint: + trace "reading exact bytes from stream", size = size await s.readExactly(addr result[0], int(size)) - except LPStreamIncompleteError, LPStreamReadError: - trace "remote connection closed", exc = getCurrentExceptionMsg() + except LPStreamIncompleteError as exc: + trace "remote connection ended unexpectedly", exc = exc.msg + except LPStreamReadError as exc: + trace "couldn't read from stream", exc = exc.msg proc writeLp*(s: Connection, msg: string | seq[byte]): Future[void] {.gcsafe.} = ## write lenght prefixed var buf = initVBuffer() buf.writeSeq(msg) buf.finish() - result = s.write(buf.buffer) + 
s.write(buf.buffer) method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} = ## get resolved multiaddresses for the connection result = c.observedAddrs proc `$`*(conn: Connection): string = - if conn.peerInfo.peerId.isSome: - result = $(conn.peerInfo.peerId.get()) + result = $(conn.peerInfo) diff --git a/libp2p/daemon/daemonapi.nim b/libp2p/daemon/daemonapi.nim index f9e64ade9..55eda1976 100644 --- a/libp2p/daemon/daemonapi.nim +++ b/libp2p/daemon/daemonapi.nim @@ -855,7 +855,7 @@ proc connect*(api: DaemonAPI, peer: PeerID, timeout)) pb.withMessage() do: discard - finally: + except: await api.closeConnection(transp) proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} = diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index c73bf8d9a..13087f143 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -7,12 +7,12 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import sequtils, strutils, strformat +import strutils import chronos, chronicles import connection, - varint, vbuffer, - protocols/protocol + protocols/protocol, + stream/lpstream logScope: topic = "Multistream" @@ -56,16 +56,16 @@ proc select*(m: MultisteamSelect, trace "selecting proto", proto = proto await conn.writeLp((proto[0] & "\n")) # select proto - result = cast[string](await conn.readLp()) # read ms header + result = cast[string]((await conn.readLp())) # read ms header result.removeSuffix("\n") if result != Codec: - trace "handshake failed", codec = result + trace "handshake failed", codec = result.toHex() return "" if proto.len() == 0: # no protocols, must be a handshake call return - result = cast[string](await conn.readLp()) # read the first proto + result = cast[string]((await conn.readLp())) # read the first proto trace "reading first requested proto" result.removeSuffix("\n") if result == proto[0]: @@ -76,7 +76,7 @@ proc select*(m: MultisteamSelect, trace "selecting one of several protos" for p in proto[1.. 0: list.add(s) @@ -111,8 +111,10 @@ proc list*(m: MultisteamSelect, proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} = trace "handle: starting multistream handling" - while not conn.closed: - var ms = cast[string](await conn.readLp()) + try: + while not conn.closed: + await sleepAsync(1.millis) + var ms = cast[string]((await conn.readLp())) ms.removeSuffix("\n") trace "handle: got request for ", ms @@ -142,11 +144,15 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} = try: await h.protocol.handler(conn, ms) return - except Exception as exc: - warn "exception while handling ", msg = exc.msg + except CatchableError as exc: + warn "exception while handling", msg = exc.msg return warn "no handlers for ", protocol = ms await conn.write(m.na) + except CatchableError as exc: + trace "exception occured", exc = exc.msg + finally: + trace "leaving multistream loop" proc addHandler*[T: LPProtocol](m: MultisteamSelect, codec: string, diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim index 26c7bf9f1..b005d4c69 100644 --- a/libp2p/muxers/mplex/coder.nim +++ b/libp2p/muxers/mplex/coder.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. 
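
The `newConnection` change earlier in this patch wires the wrapped stream's `closeEvent` to the connection so that closing the inner stream propagates outward; the same callback pattern reappears below in `newMplex` and in secio's `handleConn`. A minimal standalone sketch of that pattern, assuming only chronos; `Wrapped` and `Wrapper` are hypothetical stand-ins for the stream/connection pair:

```nim
import chronos

type
  Wrapped = ref object
    closeEvent: AsyncEvent
  Wrapper = ref object
    inner: Wrapped
    isClosed: bool

proc close(w: Wrapper) {.async.} =
  w.isClosed = true

proc newWrapper(inner: Wrapped): Wrapper =
  new result
  result.inner = inner
  let this = result
  # when the wrapped object signals close, close the wrapper as well
  inner.closeEvent.wait().addCallback(
    proc(udata: pointer) =
      if not this.isClosed:
        asyncCheck this.close())

when isMainModule:
  let inner = Wrapped(closeEvent: newAsyncEvent())
  let outer = newWrapper(inner)
  inner.closeEvent.fire()
  waitFor sleepAsync(10.millis)  # let the event loop run the callback
  doAssert outer.isClosed
```
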
-import chronos, options, sequtils, strformat +import chronos, options import nimcrypto/utils, chronicles import types, ../../connection, @@ -29,31 +29,33 @@ proc readMplexVarint(conn: Connection): Future[Option[uint]] {.async, gcsafe.} = varint: uint length: int res: VarintStatus - var buffer = newSeq[byte](10) + buffer = newSeq[byte](10) + result = none(uint) try: for i in 0.. 0.uint: - trace "readMsg: read size varint ", varint = dataLenVarint data = await conn.read(dataLenVarint.get().int) + trace "read size varint", varint = dataLenVarint let header = headerVarint.get() result = some((header shr 3, MessageType(header and 0x7), data)) @@ -64,11 +66,13 @@ proc writeMsg*(conn: Connection, data: seq[byte] = @[]) {.async, gcsafe.} = ## write lenght prefixed var buf = initVBuffer() - let header = (id shl 3 or ord(msgType).uint) - buf.writeVarint(id shl 3 or ord(msgType).uint) - buf.writeVarint(data.len().uint) # size should be always sent + buf.writePBVarint(id shl 3 or ord(msgType).uint) + buf.writePBVarint(data.len().uint) # size should be always sent buf.finish() - await conn.write(buf.buffer & data) + try: + await conn.write(buf.buffer & data) + except LPStreamIncompleteError as exc: + trace "unable to send message", exc = exc.msg proc writeMsg*(conn: Connection, id: uint, diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 0eb665b90..ca012fee8 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -7,7 +7,6 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import strformat import chronos, chronicles import types, coder, @@ -52,99 +51,110 @@ proc newChannel*(id: uint, result.asyncLock = newAsyncLock() let chan = result - proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = + proc writeHandler(data: seq[byte]): Future[void] {.async.} = # writes should happen in sequence await chan.asyncLock.acquire() - trace "writeHandler: sending data ", data = data.toHex(), id = chan.id + trace "sending data ", data = data.toHex(), + id = chan.id, + initiator = chan.initiator + await conn.writeMsg(chan.id, chan.msgCode, data) # write header chan.asyncLock.release() result.initBufferStream(writeHandler, size) -proc closeMessage(s: LPChannel) {.async, gcsafe.} = +proc closeMessage(s: LPChannel) {.async.} = await s.conn.writeMsg(s.id, s.closeCode) # write header -proc closed*(s: LPChannel): bool = - s.closedLocal and s.closedLocal - proc closedByRemote*(s: LPChannel) {.async.} = s.closedRemote = true proc cleanUp*(s: LPChannel): Future[void] = + # method which calls the underlying buffer's `close` + # method used instead of `close` since it's overloaded to + # simulate half-closed streams result = procCall close(BufferStream(s)) +proc open*(s: LPChannel): Future[void] = + s.conn.writeMsg(s.id, MessageType.New, s.name) + method close*(s: LPChannel) {.async, gcsafe.} = s.closedLocal = true await s.closeMessage() -proc resetMessage(s: LPChannel) {.async, gcsafe.} = +proc resetMessage(s: LPChannel) {.async.} = await s.conn.writeMsg(s.id, s.resetCode) -proc resetByRemote*(s: LPChannel) {.async, gcsafe.} = +proc resetByRemote*(s: LPChannel) {.async.} = await allFutures(s.close(), s.closedByRemote()) s.isReset = true proc reset*(s: LPChannel) {.async.} = await allFutures(s.resetMessage(), s.resetByRemote()) -proc isReadEof(s: LPChannel): bool = - bool((s.closedRemote or s.closedLocal) and s.len() < 1) +method closed*(s: LPChannel): bool = + result = s.closedRemote and s.len == 0 
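
The `writeMsg`/`readMsg` pair above frames every mplex message as a varint header (`id shl 3 or msgType`) followed by a varint length prefix and the payload. A rough standalone illustration of that header math, using a plain unsigned LEB128 varint sketch rather than the library's PB/LP varint helpers; the `MsgOut` code is illustrative:

```nim
proc putUVarint(buf: var seq[byte], value: uint64) =
  ## append `value` as an unsigned LEB128 varint
  var v = value
  while true:
    let b = byte(v and 0x7F)
    v = v shr 7
    if v == 0:
      buf.add(b)
      break
    buf.add(b or 0x80)

proc getUVarint(buf: openArray[byte], offset: var int): uint64 =
  ## read an unsigned LEB128 varint starting at `offset`
  var shift = 0
  while true:
    let b = buf[offset]
    inc offset
    result = result or (uint64(b and 0x7F) shl shift)
    if (b and 0x80) == 0:
      break
    shift += 7

when isMainModule:
  const MsgOut = 2'u64          # illustrative mplex message-type code
  let id = 5'u64
  let payload = @[byte 1, 2, 3]

  var frame: seq[byte]
  putUVarint(frame, (id shl 3) or MsgOut)   # header: stream id + message type
  putUVarint(frame, uint64(payload.len))    # length prefix, always sent
  frame.add(payload)

  var pos = 0
  let header = getUVarint(frame, pos)
  doAssert header shr 3 == id
  doAssert (header and 0x7) == MsgOut
  doAssert getUVarint(frame, pos) == 3
```
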
-proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] {.gcsafe.} = - if s.closedRemote: +proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] = + if s.closedRemote or s.isReset: raise newLPStreamClosedError() + trace "pushing data to channel", data = data.toHex(), + id = s.id, + initiator = s.initiator + result = procCall pushTo(BufferStream(s), data) -method read*(s: LPChannel, n = -1): Future[seq[byte]] {.gcsafe.} = - if s.isReadEof(): +method read*(s: LPChannel, n = -1): Future[seq[byte]] = + if s.closed or s.isReset: raise newLPStreamClosedError() + result = procCall read(BufferStream(s), n) -method readExactly*(s: LPChannel, - pbytes: pointer, - nbytes: int): - Future[void] {.gcsafe.} = - if s.isReadEof(): +method readExactly*(s: LPChannel, + pbytes: pointer, + nbytes: int): + Future[void] = + if s.closed or s.isReset: raise newLPStreamClosedError() result = procCall readExactly(BufferStream(s), pbytes, nbytes) method readLine*(s: LPChannel, limit = 0, sep = "\r\n"): - Future[string] {.gcsafe.} = - if s.isReadEof(): + Future[string] = + if s.closed or s.isReset: raise newLPStreamClosedError() result = procCall readLine(BufferStream(s), limit, sep) method readOnce*(s: LPChannel, pbytes: pointer, nbytes: int): - Future[int] {.gcsafe.} = - if s.isReadEof(): + Future[int] = + if s.closed or s.isReset: raise newLPStreamClosedError() result = procCall readOnce(BufferStream(s), pbytes, nbytes) method readUntil*(s: LPChannel, pbytes: pointer, nbytes: int, sep: seq[byte]): - Future[int] {.gcsafe.} = - if s.isReadEof(): + Future[int] = + if s.closed or s.isReset: raise newLPStreamClosedError() result = procCall readOnce(BufferStream(s), pbytes, nbytes) method write*(s: LPChannel, pbytes: pointer, - nbytes: int): Future[void] {.gcsafe.} = - if s.closedLocal: + nbytes: int): Future[void] = + if s.closedLocal or s.isReset: raise newLPStreamClosedError() result = procCall write(BufferStream(s), pbytes, nbytes) -method write*(s: LPChannel, msg: string, msglen = -1) {.async, gcsafe.} = - if s.closedLocal: +method write*(s: LPChannel, msg: string, msglen = -1) {.async.} = + if s.closedLocal or s.isReset: raise newLPStreamClosedError() result = procCall write(BufferStream(s), msg, msglen) -method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async, gcsafe.} = - if s.closedLocal: +method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} = + if s.closedLocal or s.isReset: raise newLPStreamClosedError() result = procCall write(BufferStream(s), msg, msglen) diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 7b8d760de..2f5328c36 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -11,16 +11,14 @@ ## Timeouts and message limits are still missing ## they need to be added ASAP -import tables, sequtils, options, strformat +import tables, sequtils, options import chronos, chronicles -import coder, types, lpchannel, - ../muxer, - ../../varint, +import ../muxer, ../../connection, - ../../vbuffer, - ../../protocols/protocol, - ../../stream/bufferstream, - ../../stream/lpstream + ../../stream/lpstream, + coder, + types, + lpchannel logScope: topic = "Mplex" @@ -34,9 +32,11 @@ type proc getChannelList(m: Mplex, initiator: bool): var Table[uint, LPChannel] = if initiator: - result = m.remote - else: + trace "picking local channels", initiator = initiator result = m.local + else: + trace "picking remote channels", initiator = initiator + result = m.remote proc newStreamInternal*(m: Mplex, initiator: bool = true, @@ -45,17 +45,28 @@ 
proc newStreamInternal*(m: Mplex, Future[LPChannel] {.async, gcsafe.} = ## create new channel/stream let id = if initiator: m.currentId.inc(); m.currentId else: chanId + trace "creating new channel", channelId = id, initiator = initiator result = newChannel(id, m.connection, initiator, name) m.getChannelList(initiator)[id] = result +proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async, inline.} = + ## call the channel's `close` to signal the + ## remote that the channel is closing + if not isNil(chann) and not chann.closed: + await chann.close() + await chann.cleanUp() + m.getChannelList(initiator).del(chann.id) + trace "cleaned up channel", id = chann.id + method handle*(m: Mplex) {.async, gcsafe.} = trace "starting mplex main loop" try: while not m.connection.closed: + trace "waiting for data" let msg = await m.connection.readMsg() if msg.isNone: # TODO: allow poll with timeout to avoid using `sleepAsync` - await sleepAsync(10.millis) + await sleepAsync(1.millis) continue let (id, msgType, data) = msg.get() @@ -63,8 +74,11 @@ method handle*(m: Mplex) {.async, gcsafe.} = var channel: LPChannel if MessageType(msgType) != MessageType.New: let channels = m.getChannelList(initiator) - if not channels.contains(id): - trace "handle: Channel with id and msg type ", id = id, msg = msgType + if id notin channels: + trace "Channel not found, skipping", id = id, + initiator = initiator, + msg = msgType + await sleepAsync(1.millis) continue channel = channels[id] @@ -72,36 +86,44 @@ method handle*(m: Mplex) {.async, gcsafe.} = of MessageType.New: let name = cast[string](data) channel = await m.newStreamInternal(false, id, name) - trace "handle: created channel ", id = id, name = name + trace "created channel", id = id, name = name, inititator = true if not isNil(m.streamHandler): let stream = newConnection(channel) stream.peerInfo = m.connection.peerInfo - let handlerFut = m.streamHandler(stream) - # channel cleanup routine - proc cleanUpChan(udata: pointer) {.gcsafe.} = - if handlerFut.finished: - channel.close().addCallback( - proc(udata: pointer) = - channel.cleanUp() - .addCallback(proc(udata: pointer) = - trace "handle: cleaned up channel ", id = id)) - handlerFut.addCallback(cleanUpChan) + # cleanup channel once handler is finished + # stream.closeEvent.wait().addCallback( + # proc(udata: pointer) = + # asyncCheck cleanupChann(m, channel, initiator)) + + asyncCheck m.streamHandler(stream) + continue of MessageType.MsgIn, MessageType.MsgOut: - trace "handle: pushing data to channel ", id = id, msgType = msgType + trace "pushing data to channel", id = id, + initiator = initiator, + msgType = msgType + await channel.pushTo(data) of MessageType.CloseIn, MessageType.CloseOut: - trace "handle: closing channel ", id = id, msgType = msgType + trace "closing channel", id = id, + initiator = initiator, + msgType = msgType + await channel.closedByRemote() m.getChannelList(initiator).del(id) of MessageType.ResetIn, MessageType.ResetOut: - trace "handle: resetting channel ", id = id + trace "resetting channel", id = id, + initiator = initiator, + msgType = msgType + await channel.resetByRemote() + m.getChannelList(initiator).del(id) break - except: - error "exception occurred", exception = getCurrentExceptionMsg() + except CatchableError as exc: + trace "exception occurred", exception = exc.msg finally: + trace "stopping mplex main loop" await m.connection.close() proc newMplex*(conn: Connection, @@ -112,13 +134,20 @@ proc newMplex*(conn: Connection, result.remote = initTable[uint, 
LPChannel]() result.local = initTable[uint, LPChannel]() + let m = result + conn.closeEvent.wait().addCallback( + proc(udata: pointer) = + asyncCheck m.close() + ) + method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} = let channel = await m.newStreamInternal() - await m.connection.writeMsg(channel.id, MessageType.New, name) + # TODO: open the channel (this should be lazy) + await channel.open() result = newConnection(channel) result.peerInfo = m.connection.peerInfo method close*(m: Mplex) {.async, gcsafe.} = - await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.close())), - allFutures(toSeq(m.local.values).mapIt(it.close()))]) - m.connection.reset() + trace "closing mplex muxer" + await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())), + allFutures(toSeq(m.local.values).mapIt(it.reset()))]) diff --git a/libp2p/muxers/mplex/types.nim b/libp2p/muxers/mplex/types.nim index af09049c0..559e20249 100644 --- a/libp2p/muxers/mplex/types.nim +++ b/libp2p/muxers/mplex/types.nim @@ -8,7 +8,6 @@ ## those terms. import chronos -import ../../connection const MaxMsgSize* = 1 shl 20 # 1mb const MaxChannels* = 1000 diff --git a/libp2p/peerinfo.nim b/libp2p/peerinfo.nim index 6e529217c..3d21fc017 100644 --- a/libp2p/peerinfo.nim +++ b/libp2p/peerinfo.nim @@ -10,7 +10,27 @@ import options import peer, multiaddress -type PeerInfo* = object of RootObj - peerId*: Option[PeerID] - addrs*: seq[MultiAddress] - protocols*: seq[string] +type + PeerInfo* = object of RootObj + peerId*: Option[PeerID] + addrs*: seq[MultiAddress] + protocols*: seq[string] + +proc id*(p: PeerInfo): string = + if p.peerId.isSome: + result = p.peerId.get().pretty + +proc `$`*(p: PeerInfo): string = + if p.peerId.isSome: + result.add("PeerID: ") + result.add(p.id & "\n") + + if p.addrs.len > 0: + result.add("Peer Addrs: ") + for a in p.addrs: + result.add($a & "\n") + + if p.protocols.len > 0: + result.add("Protocols: ") + for proto in p.protocols: + result.add(proto & "\n") diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index 8ae65efc3..e060edff2 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import options, strformat +import options import chronos, chronicles import ../protobuf/minprotobuf, ../peerinfo, @@ -115,14 +115,14 @@ method init*(p: Identify) = trace "handling identify request" var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs()) await conn.writeLp(pb.buffer) + # await conn.close() #TODO: investigate why this breaks p.handler = handle p.codec = IdentifyCodec proc identify*(p: Identify, conn: Connection, - remotePeerInfo: PeerInfo): - Future[IdentifyInfo] {.async.} = + remotePeerInfo: PeerInfo): Future[IdentifyInfo] {.async, gcsafe.} = var message = await conn.readLp() if len(message) == 0: trace "identify: Invalid or empty message received!" 
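
The reworked mplex `close` above tears the muxer down by resetting every local and remote channel concurrently with `allFutures`. A minimal sketch of that fan-out pattern, assuming only chronos, tables and sequtils; `FakeChan` is a hypothetical stand-in for `LPChannel`:

```nim
import tables, sequtils, chronos

type FakeChan = ref object
  id: uint
  isReset: bool

proc reset(c: FakeChan) {.async.} =
  c.isReset = true

proc closeAll(chans: Table[uint, FakeChan]) {.async.} =
  # reset every channel concurrently and wait for all of them
  await allFutures(toSeq(chans.values).mapIt(it.reset()))

when isMainModule:
  var chans = initTable[uint, FakeChan]()
  for i in 0 .. 2:
    chans[uint(i)] = FakeChan(id: uint(i))
  waitFor closeAll(chans)
  doAssert toSeq(chans.values).allIt(it.isReset)
```
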
@@ -139,7 +139,7 @@ proc identify*(p: Identify, if peer != remotePeerInfo.peerId.get(): trace "Peer ids don't match", remote = peer.pretty(), - local = remotePeerInfo.peerId.get().pretty() + local = remotePeerInfo.id raise newException(IdentityNoMatchError, "Peer ids don't match") @@ -149,5 +149,4 @@ proc identify*(p: Identify, proc push*(p: Identify, conn: Connection) {.async.} = await conn.write(IdentifyPushCodec) var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs()) - let length = pb.getLen() await conn.writeLp(pb.buffer) diff --git a/libp2p/protocols/protocol.nim b/libp2p/protocols/protocol.nim index d5e280880..7c4cbb27d 100644 --- a/libp2p/protocols/protocol.nim +++ b/libp2p/protocols/protocol.nim @@ -8,9 +8,7 @@ ## those terms. import chronos -import ../connection, - ../peerinfo, - ../multiaddress +import ../connection type LPProtoHandler* = proc (conn: Connection, diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 2d9ef3f98..a834ec624 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -14,6 +14,7 @@ import rpcmsg, ../../peer, ../../peerinfo, ../../connection, + ../../stream/lpstream, ../../crypto/crypto, ../../protobuf/minprotobuf @@ -45,7 +46,7 @@ proc handle*(p: PubSubPeer) {.async, gcsafe.} = trace "Decoded msg from peer", peer = p.id, msg = msg await p.handler(p, @[msg]) except: - error "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg() + trace "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg() finally: trace "closing connection to pubsub peer", peer = p.id await p.conn.close() diff --git a/libp2p/protocols/secure/plaintext.nim b/libp2p/protocols/secure/plaintext.nim index f1bf2a057..5841fc4c7 100644 --- a/libp2p/protocols/secure/plaintext.nim +++ b/libp2p/protocols/secure/plaintext.nim @@ -8,8 +8,7 @@ ## those terms. import chronos -import secure, - ../../connection +import secure, ../../connection const PlainTextCodec* = "/plaintext/1.0.0" diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index c473ba1f0..59ff745d7 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -6,10 +6,12 @@ ## at your option. ## This file may not be copied, modified, or distributed except according to ## those terms. +import options import chronos, chronicles import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode] import secure, ../../connection, + ../../stream/lpstream, ../../crypto/crypto, ../../crypto/ecnist, ../../protobuf/minprotobuf, @@ -60,7 +62,6 @@ type ctxsha1: HMAC[sha1] SecureConnection* = ref object of Connection - conn*: Connection writerMac: SecureMac readerMac: SecureMac writerCoder: SecureCipher @@ -176,13 +177,13 @@ proc readMessage*(sconn: SecureConnection): Future[seq[byte]] {.async.} = ## Read message from channel secure connection ``sconn``. 
try: var buf = newSeq[byte](4) - await sconn.conn.readExactly(addr buf[0], 4) + await sconn.readExactly(addr buf[0], 4) let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or (int(buf[2]) shl 8) or (int(buf[3])) trace "Recieved message header", header = toHex(buf), length = length if length <= SecioMaxMessageSize: buf.setLen(length) - await sconn.conn.readExactly(addr buf[0], length) + await sconn.readExactly(addr buf[0], length) trace "Received message body", length = length, buffer = toHex(buf) if sconn.macCheckAndDecode(buf): @@ -213,21 +214,27 @@ proc writeMessage*(sconn: SecureConnection, message: seq[byte]) {.async.} = msg[3] = byte(length and 0xFF) trace "Writing message", message = toHex(msg) try: - await sconn.conn.write(msg) + await sconn.write(msg) except AsyncStreamWriteError: trace "Could not write to connection" -proc newSecureConnection*(conn: Connection, hash: string, cipher: string, +proc newSecureConnection*(conn: Connection, + hash: string, + cipher: string, secrets: Secret, - order: int): SecureConnection = + order: int, + peerId: PeerID): SecureConnection = ## Create new secure connection, using specified hash algorithm ``hash``, ## cipher algorithm ``cipher``, stretched keys ``secrets`` and order ## ``order``. new result + + result.stream = conn + result.closeEvent = newAsyncEvent() + let i0 = if order < 0: 1 else: 0 let i1 = if order < 0: 0 else: 1 - result.conn = conn trace "Writer credentials", mackey = toHex(secrets.macOpenArray(i0)), enckey = toHex(secrets.keyOpenArray(i0)), iv = toHex(secrets.ivOpenArray(i0)) @@ -241,6 +248,8 @@ proc newSecureConnection*(conn: Connection, hash: string, cipher: string, result.readerCoder.init(cipher, secrets.keyOpenArray(i1), secrets.ivOpenArray(i1)) + result.peerInfo.peerId = some(peerId) + proc transactMessage(conn: Connection, msg: seq[byte]): Future[seq[byte]] {.async.} = var buf = newSeq[byte](4) @@ -281,7 +290,6 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.} remoteHashes: string remotePeerId: PeerID localPeerId: PeerID - ekey: PrivateKey localBytesPubkey = s.localPublicKey.getBytes() if randomBytes(localNonce) != SecioNonceSize: @@ -388,7 +396,8 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.} # Perform Nonce exchange over encrypted channel. 
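
`readMessage`/`writeMessage` above wrap every secio record in a 4-byte big-endian length prefix before the MAC check and decryption. A standalone sketch of just that framing step, with the crypto omitted:

```nim
proc frame(payload: seq[byte]): seq[byte] =
  ## prepend a 4-byte big-endian length to `payload`
  let length = payload.len
  result = newSeq[byte](4 + length)
  result[0] = byte((length shr 24) and 0xFF)
  result[1] = byte((length shr 16) and 0xFF)
  result[2] = byte((length shr 8) and 0xFF)
  result[3] = byte(length and 0xFF)
  if length > 0:
    copyMem(addr result[4], unsafeAddr payload[0], length)

proc parseHeader(buf: openArray[byte]): int =
  ## decode the 4-byte big-endian length prefix
  doAssert buf.len >= 4
  result = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
           (int(buf[2]) shl 8) or int(buf[3])

when isMainModule:
  let msg = cast[seq[byte]]("hello secio")
  let framed = frame(msg)
  doAssert parseHeader(framed) == msg.len
  doAssert framed[4 .. ^1] == msg
```
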
- result = newSecureConnection(conn, hash, cipher, keys, order) + result = newSecureConnection(conn, hash, cipher, keys, order, remotePeerId) + await result.writeMessage(remoteNonce) var res = await result.readMessage() @@ -400,17 +409,21 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.} trace "Secure handshake succeeded" proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} = - while not sconn.conn.closed: - try: + try: + while not sconn.closed: let msg = await sconn.readMessage() - await stream.pushTo(msg) - except CatchableError as exc: - trace "exception in secio", exc = exc.msg - return - finally: - trace "ending secio readLoop" + if msg.len > 0: + await stream.pushTo(msg) + + # tight loop, give a chance for other + # stuff to run as well + await sleepAsync(1.millis) + except CatchableError as exc: + trace "exception occured", exc = exc.msg + finally: + trace "ending secio readLoop", isclosed = sconn.closed() -proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} = +proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn) proc writeHandler(data: seq[byte]) {.async, gcsafe.} = trace "sending encrypted bytes", bytes = data.toHex() @@ -419,7 +432,13 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async.} = var stream = newBufferStream(writeHandler) asyncCheck readLoop(sconn, stream) var secured = newConnection(stream) - secured.peerInfo = sconn.conn.peerInfo + secured.closeEvent.wait() + .addCallback(proc(udata: pointer) = + trace "wrapped connection closed, closing upstream" + if not sconn.closed: + asyncCheck sconn.close() + ) + secured.peerInfo.peerId = sconn.peerInfo.peerId result = secured method init(s: Secio) {.gcsafe.} = diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 33d2fd502..3aba26dab 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -8,7 +8,7 @@ ## those terms. ## This module implements an asynchronous buffer stream -## which emulates physical async IO. +## which emulates physical async IO. ## ## The stream is based on the standard library's `Deque`, ## which is itself based on a ring buffer. @@ -25,12 +25,12 @@ ## ordered and asynchronous. Reads are queued up in order ## and are suspended when not enough data available. This ## allows preserving backpressure while maintaining full -## asynchrony. Both writting to the internal buffer with +## asynchrony. Both writting to the internal buffer with ## ``pushTo`` as well as reading with ``read*` methods, ## will suspend until either the amount of elements in the ## buffer goes below ``maxSize`` or more data becomes available. 
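
The module doc comment above describes the BufferStream contract: `pushTo` feeds the internal ring buffer and suspends when it is full, while the `read*` methods drain it in order. A hedged usage sketch, assuming the bufferstream module from this patch is importable under the path shown:

```nim
import chronos
import libp2p/stream/bufferstream

proc demo() {.async.} =
  # called only for bytes written *out* through this stream with `write`
  proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
    discard

  let buff = newBufferStream(writeHandler, 16)
  await buff.pushTo(cast[seq[byte]]("hi there"))  # fill the read buffer
  let msg = await buff.read(8)                    # drain it in order
  doAssert cast[string](msg) == "hi there"

when isMainModule:
  waitFor demo()
```
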
-import deques, tables, sequtils, math +import deques, math import chronos import ../stream/lpstream @@ -38,33 +38,49 @@ const DefaultBufferSize* = 1024 type # TODO: figure out how to make this generic to avoid casts - WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} + WriteHandler* = proc (data: seq[byte]): Future[void] BufferStream* = ref object of LPStream maxSize*: int # buffer's max size in bytes - readBuf: Deque[byte] # a deque is based on a ring buffer + readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here readReqs: Deque[Future[void]] # use dequeue to fire reads in order dataReadEvent: AsyncEvent writeHandler*: WriteHandler + lock: AsyncLock + isPiped: bool -proc requestReadBytes(s: BufferStream): Future[void] = + AlreadyPipedError* = object of CatchableError + NotWritableError* = object of CatchableError + +proc newAlreadyPipedError*(): ref Exception {.inline.} = + result = newException(AlreadyPipedError, "stream already piped") + +proc newNotWritableError*(): ref Exception {.inline.} = + result = newException(NotWritableError, "stream is not writable") + +proc requestReadBytes(s: BufferStream): Future[void] = ## create a future that will complete when more ## data becomes available in the read buffer result = newFuture[void]() s.readReqs.addLast(result) -proc initBufferStream*(s: BufferStream, handler: WriteHandler, size: int = DefaultBufferSize) = +proc initBufferStream*(s: BufferStream, + handler: WriteHandler = nil, + size: int = DefaultBufferSize) = s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size) s.readBuf = initDeque[byte](s.maxSize) s.readReqs = initDeque[Future[void]]() s.dataReadEvent = newAsyncEvent() + s.lock = newAsyncLock() s.writeHandler = handler + s.closeEvent = newAsyncEvent() -proc newBufferStream*(handler: WriteHandler, size: int = DefaultBufferSize): BufferStream = +proc newBufferStream*(handler: WriteHandler = nil, + size: int = DefaultBufferSize): BufferStream = new result result.initBufferStream(handler, size) -proc popFirst*(s: BufferStream): byte = +proc popFirst*(s: BufferStream): byte = result = s.readBuf.popFirst() s.dataReadEvent.fire() @@ -78,15 +94,24 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) = proc len*(s: BufferStream): int = s.readBuf.len -proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} = +proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} = ## Write bytes to internal read buffer, use this to fill up the ## buffer with data. ## ## This method is async and will wait until all data has been ## written to the internal buffer; this is done so that backpressure ## is preserved. + ## + + await s.lock.acquire() var index = 0 while true: + + # give readers a chance free up the buffer + # it it's full. 
+ if s.readBuf.len >= s.maxSize: + await sleepAsync(10.millis) + while index < data.len and s.readBuf.len < s.maxSize: s.readBuf.addLast(data[index]) inc(index) @@ -94,18 +119,20 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async, gcsafe.} = # resolve the next queued read request if s.readReqs.len > 0: s.readReqs.popFirst().complete() - + if index >= data.len: break - + # if we couldn't transfer all the data to the # internal buf wait on a read event await s.dataReadEvent.wait() + s.lock.release() -method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} = +method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async.} = ## Read all bytes (n <= 0) or exactly `n` bytes from buffer ## ## This procedure allocates buffer seq[byte] and return it as result. + ## var size = if n > 0: n else: s.readBuf.len() var index = 0 while index < size: @@ -116,25 +143,26 @@ method read*(s: BufferStream, n = -1): Future[seq[byte]] {.async, gcsafe.} = if index < size: await s.requestReadBytes() -method readExactly*(s: BufferStream, - pbytes: pointer, - nbytes: int): - Future[void] {.async, gcsafe.} = +method readExactly*(s: BufferStream, + pbytes: pointer, + nbytes: int): + Future[void] {.async.} = ## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store ## it to ``pbytes``. ## ## If EOF is received and ``nbytes`` is not yet read, the procedure ## will raise ``LPStreamIncompleteError``. - let buff = await s.read(nbytes) + ## + var buff = await s.read(nbytes) if nbytes > buff.len(): raise newLPStreamIncompleteError() - copyMem(pbytes, unsafeAddr buff[0], nbytes) + copyMem(pbytes, addr buff[0], nbytes) method readLine*(s: BufferStream, limit = 0, - sep = "\r\n"): - Future[string] {.async, gcsafe.} = + sep = "\r\n"): + Future[string] {.async.} = ## Read one line from read-only stream ``rstream``, where ``"line"`` is a ## sequence of bytes ending with ``sep`` (default is ``"\r\n"``). ## @@ -146,6 +174,7 @@ method readLine*(s: BufferStream, ## ## If ``limit`` more then 0, then result string will be limited to ``limit`` ## bytes. + ## result = "" var lim = if limit <= 0: -1 else: limit var state = 0 @@ -170,14 +199,15 @@ method readLine*(s: BufferStream, method readOnce*(s: BufferStream, pbytes: pointer, nbytes: int): - Future[int] {.async, gcsafe.} = + Future[int] {.async.} = ## Perform one read operation on read-only stream ``rstream``. ## ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## internal buffer, otherwise it will wait until some bytes will be received. + ## if s.readBuf.len == 0: await s.requestReadBytes() - + var len = if nbytes > s.readBuf.len: s.readBuf.len else: nbytes await s.readExactly(pbytes, len) result = len @@ -186,7 +216,7 @@ method readUntil*(s: BufferStream, pbytes: pointer, nbytes: int, sep: seq[byte]): - Future[int] {.async, gcsafe.} = + Future[int] {.async.} = ## Read data from the read-only stream ``rstream`` until separator ``sep`` is ## found. ## @@ -200,6 +230,7 @@ method readUntil*(s: BufferStream, ## will raise ``LPStreamLimitError``. ## ## Procedure returns actual number of bytes read. + ## var dest = cast[ptr UncheckedArray[byte]](pbytes) state = 0 @@ -231,22 +262,22 @@ method readUntil*(s: BufferStream, else: s.shrink(datalen) -method write*(s: BufferStream, - pbytes: pointer, - nbytes: int): Future[void] - {.gcsafe.} = +method write*(s: BufferStream, + pbytes: pointer, + nbytes: int): Future[void] = ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream ## ``rstream``. 
## ## Return number of bytes actually consumed (discarded). + ## var buf: seq[byte] = newSeq[byte](nbytes) copyMem(addr buf[0], pbytes, nbytes) - result = s.writeHandler(buf) + if not isNil(s.writeHandler): + result = s.writeHandler(buf) method write*(s: BufferStream, msg: string, - msglen = -1): Future[void] - {.gcsafe.} = + msglen = -1): Future[void] = ## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``. ## ## String ``sbytes`` must not be zero-length. @@ -254,14 +285,15 @@ method write*(s: BufferStream, ## If ``msglen < 0`` whole string ``sbytes`` will be writen to stream. ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. + ## var buf = "" shallowCopy(buf, if msglen > 0: msg[0.. len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. + ## var buf: seq[byte] shallowCopy(buf, if msglen > 0: msg[0.. 0: result.addrs = info.addrs - + if info.protos.len > 0: result.protocols = info.protos - trace "identify: identified remote peer ", peer = result.peerId.get().pretty except IdentityInvalidMsgError as exc: error "identify: invalid message", msg = exc.msg except IdentityNoMatchError as exc: @@ -100,22 +104,23 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = muxer.streamHandler = s.streamHandler # new stream for identify - let stream = await muxer.newStream() + var stream = await muxer.newStream() let handlerFut = muxer.handle() # add muxer handler cleanup proc handlerFut.addCallback( proc(udata: pointer = nil) {.gcsafe.} = - trace "mux: Muxer handler completed for peer ", - peer = conn.peerInfo.peerId.get().pretty + trace "muxer handler completed for peer", + peer = conn.peerInfo.id ) # do identify first, so that we have a # PeerInfo in case we didn't before conn.peerInfo = await s.identify(stream) - await stream.close() # close idenity stream - - trace "connection's peerInfo", peerInfo = conn.peerInfo.peerId + + await stream.close() # close identify stream + + trace "connection's peerInfo", peerInfo = conn.peerInfo # store it in muxed connections if we have a peer for it # TODO: We should make sure that this are cleaned up properly @@ -123,43 +128,42 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = # happen once secio is in place, but still something to keep # in mind if conn.peerInfo.peerId.isSome: - trace "adding muxer for peer", peer = conn.peerInfo.peerId.get().pretty - s.muxed[conn.peerInfo.peerId.get().pretty] = muxer + trace "adding muxer for peer", peer = conn.peerInfo.id + s.muxed[conn.peerInfo.id] = muxer proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = if conn.peerInfo.peerId.isSome: - let id = conn.peerInfo.peerId.get().pretty - if s.muxed.contains(id): - await s.muxed[id].close - - if s.connections.contains(id): + let id = conn.peerInfo.id + trace "cleaning up connection for peer", peerId = id + if id in s.muxed: + await s.muxed[id].close() + s.muxed.del(id) + + if id in s.connections: await s.connections[id].close() + s.connections.del(id) proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} = # if there is a muxer for the connection # use it instead to create a muxed stream - if s.muxed.contains(peerInfo.peerId.get().pretty): - trace "connection is muxed, retriving muxer and setting up a stream" - let muxer = s.muxed[peerInfo.peerId.get().pretty] + if peerInfo.id in s.muxed: + trace "connection is muxed, setting up a stream" + let muxer = s.muxed[peerInfo.id] let conn = await 
muxer.newStream() result = some(conn) proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = trace "handling connection", conn = conn result = conn - ## perform upgrade flow - if result.peerInfo.peerId.isSome: - let id = result.peerInfo.peerId.get().pretty - if s.connections.contains(id): - # if we already have a connection for this peer, - # close the incoming connection and return the - # existing one - await result.close() - return s.connections[id] - s.connections[id] = result - result = await s.secure(conn) # secure the connection + # don't mux/secure twise + if conn.peerInfo.peerId.isSome and + conn.peerInfo.id in s.muxed: + return + + result = await s.secure(result) # secure the connection await s.mux(result) # mux it if possible + s.connections[conn.peerInfo.id] = result proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = trace "upgrading incoming connection" @@ -192,42 +196,57 @@ proc dial*(s: Switch, peer: PeerInfo, proto: string = ""): Future[Connection] {.async.} = - trace "dialing peer", peer = peer.peerId.get().pretty + let id = peer.id + trace "dialing peer", peer = id for t in s.transports: # for each transport for a in peer.addrs: # for each address if t.handles(a): # check if it can dial it - result = await t.dial(a) - # make sure to assign the peer to the connection - result.peerInfo = peer + if id notin s.connections: + trace "dialing address", address = $a + result = await t.dial(a) + # make sure to assign the peer to the connection + result.peerInfo = peer result = await s.upgradeOutgoing(result) + result.closeEvent.wait().addCallback( + proc(udata: pointer) = + asyncCheck s.cleanupConn(result) + ) - let stream = await s.getMuxedStream(peer) - if stream.isSome: - trace "connection is muxed, return muxed stream" - result = stream.get() + if proto.len > 0 and not result.closed: + let stream = await s.getMuxedStream(peer) + if stream.isSome: + trace "connection is muxed, return muxed stream" + result = stream.get() + trace "attempting to select remote", proto = proto - trace "dial: attempting to select remote ", proto = proto - if not (await s.ms.select(result, proto)): - error "dial: Unable to select protocol: ", proto = proto - raise newException(CatchableError, - &"Unable to select protocol: {proto}") + if not (await s.ms.select(result, proto)): + error "unable to select protocol: ", proto = proto + raise newException(CatchableError, + &"unable to select protocol: {proto}") + + break # don't dial more than one addr on the same transport proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = if isNil(proto.handler): - raise newException(CatchableError, + raise newException(CatchableError, "Protocol has to define a handle method or proc") if proto.codec.len == 0: - raise newException(CatchableError, + raise newException(CatchableError, "Protocol has to define a codec string") s.ms.addHandler(proto.codec, proto) proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = + trace "starting switch" + proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: await s.upgradeIncoming(conn) # perform upgrade on incoming connection + except CatchableError as exc: + trace "exception occured", exc = exc.msg finally: + await conn.close() await s.cleanupConn(conn) var startFuts: seq[Future[void]] @@ -237,10 +256,13 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = var server = await t.listen(a, handle) s.peerInfo.addrs[i] = t.ma # update peer's address 
startFuts.add(server) + result = startFuts # listen for incoming connections proc stop*(s: Switch) {.async.} = - await allFutures(toSeq(s.connections.values).mapIt(it.close())) + trace "stopping switch" + + await allFutures(toSeq(s.connections.values).mapIt(s.cleanupConn(it))) await allFutures(s.transports.mapIt(it.close())) proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = @@ -253,14 +275,14 @@ proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] { ## subscribe to a pubsub topic if s.pubSub.isNone: raise newNoPubSubException() - + result = s.pubSub.get().subscribe(topic, handler) proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] {.gcsafe.} = ## unsubscribe from topics if s.pubSub.isNone: raise newNoPubSubException() - + result = s.pubSub.get().unsubscribe(topics) proc publish*(s: Switch, topic: string, data: seq[byte]): Future[void] {.gcsafe.} = diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index a8bf39d2e..c507ddbbf 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import chronos, chronicles +import chronos, chronicles, sequtils import transport, ../wire, ../connection, @@ -78,5 +78,5 @@ method dial*(t: TcpTransport, result = await t.connHandler(t.server, client, true) method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = - ## TODO: implement logic to properly discriminat TCP multiaddrs - true + if procCall Transport(t).handles(address): + result = address.protocols.filterIt( it == multiCodec("tcp") ).len > 0 diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index c10136943..64017b708 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -9,8 +9,7 @@ import sequtils import chronos, chronicles -import ../peerinfo, - ../connection, +import ../connection, ../multiaddress, ../multicodec @@ -62,9 +61,10 @@ method upgrade*(t: Transport) {.base, async, gcsafe.} = method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} = ## check if transport supportes the multiaddress - # TODO: this should implement generic logic that would use the multicodec - # declared in the multicodec field and set by each individual transport - discard + + # by default we skip circuit addresses to avoid + # having to repeat the check in every transport + address.protocols.filterIt( it == multiCodec("p2p-circuit") ).len == 0 method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} = ## get the local address of the transport in case started with 0.0.0.0:0 diff --git a/libp2p/vbuffer.nim b/libp2p/vbuffer.nim index 1ee5b9d66..69e268b4a 100644 --- a/libp2p/vbuffer.nim +++ b/libp2p/vbuffer.nim @@ -53,7 +53,17 @@ proc initVBuffer*(): VBuffer = ## Initialize empty VBuffer. result.buffer = newSeqOfCap[byte](128) -proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) = +proc writePBVarint*(vb: var VBuffer, value: PBSomeUVarint) = + ## Write ``value`` as variable unsigned integer. + var length = 0 + var v = value and cast[type(value)](0xFFFF_FFFF_FFFF_FFFF) + vb.buffer.setLen(len(vb.buffer) + vsizeof(v)) + let res = PB.putUVarint(toOpenArray(vb.buffer, vb.offset, len(vb.buffer) - 1), + length, v) + doAssert(res == VarintStatus.Success) + vb.offset += length + +proc writeLPVarint*(vb: var VBuffer, value: LPSomeUVarint) = ## Write ``value`` as variable unsigned integer. 
var length = 0 # LibP2P varint supports only 63 bits. @@ -64,6 +74,9 @@ proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) = doAssert(res == VarintStatus.Success) vb.offset += length +proc writeVarint*(vb: var VBuffer, value: LPSomeUVarint) = + writeLPVarint(vb, value) + proc writeSeq*[T: byte|char](vb: var VBuffer, value: openarray[T]) = ## Write array ``value`` to buffer ``vb``, value will be prefixed with ## varint length of the array. diff --git a/tests/testbufferstream.nim b/tests/testbufferstream.nim index afbf20b2a..a388c1ff0 100644 --- a/tests/testbufferstream.nim +++ b/tests/testbufferstream.nim @@ -1,4 +1,4 @@ -import unittest, deques, sequtils, strformat +import unittest, strformat import chronos import ../libp2p/stream/bufferstream @@ -220,7 +220,6 @@ suite "BufferStream": test "reads should happen in order": proc testWritePtr(): Future[bool] {.async.} = - var count = 1 proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let buff = newBufferStream(writeHandler, 10) check buff.len == 0 @@ -245,3 +244,199 @@ suite "BufferStream": check: waitFor(testWritePtr()) == true + + test "pipe two streams without the `pipe` or `|` helpers": + proc pipeTest(): Future[bool] {.async.} = + proc writeHandler1(data: seq[byte]) {.async, gcsafe.} + proc writeHandler2(data: seq[byte]) {.async, gcsafe.} + + var buf1 = newBufferStream(writeHandler1) + var buf2 = newBufferStream(writeHandler2) + + proc writeHandler1(data: seq[byte]) {.async, gcsafe.} = + var msg = cast[string](data) + check msg == "Hello!" + await buf2.pushTo(data) + + proc writeHandler2(data: seq[byte]) {.async, gcsafe.} = + var msg = cast[string](data) + check msg == "Hello!" + await buf1.pushTo(data) + + var res1: seq[byte] = newSeq[byte](7) + var readFut1 = buf1.readExactly(addr res1[0], 7) + + var res2: seq[byte] = newSeq[byte](7) + var readFut2 = buf2.readExactly(addr res2[0], 7) + + await buf1.pushTo(cast[seq[byte]]("Hello2!")) + await buf2.pushTo(cast[seq[byte]]("Hello1!")) + + await allFutures(readFut1, readFut2) + + check: + res1 == cast[seq[byte]]("Hello2!") + res2 == cast[seq[byte]]("Hello1!") + + result = true + + check: + waitFor(pipeTest()) == true + + test "pipe A -> B": + proc pipeTest(): Future[bool] {.async.} = + var buf1 = newBufferStream() + var buf2 = buf1.pipe(newBufferStream()) + + var res1: seq[byte] = newSeq[byte](7) + var readFut = buf2.readExactly(addr res1[0], 7) + await buf1.write(cast[seq[byte]]("Hello1!")) + await readFut + + check: + res1 == cast[seq[byte]]("Hello1!") + + result = true + + check: + waitFor(pipeTest()) == true + + test "pipe A -> B and B -> A": + proc pipeTest(): Future[bool] {.async.} = + var buf1 = newBufferStream() + var buf2 = newBufferStream() + + buf1 = buf1.pipe(buf2).pipe(buf1) + + var res1: seq[byte] = newSeq[byte](7) + var readFut1 = buf1.readExactly(addr res1[0], 7) + + var res2: seq[byte] = newSeq[byte](7) + var readFut2 = buf2.readExactly(addr res2[0], 7) + + await buf1.write(cast[seq[byte]]("Hello1!")) + await buf2.write(cast[seq[byte]]("Hello2!")) + await allFutures(readFut1, readFut2) + + check: + res1 == cast[seq[byte]]("Hello2!") + res2 == cast[seq[byte]]("Hello1!") + + result = true + + check: + waitFor(pipeTest()) == true + + test "pipe A -> A (echo)": + proc pipeTest(): Future[bool] {.async.} = + var buf1 = newBufferStream() + + buf1 = buf1.pipe(buf1) + + proc reader(): Future[seq[byte]] = buf1.read(6) + proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!")) + + var writerFut = writer() + var readerFut = reader() + + await writerFut + 
check: + (await readerFut) == cast[seq[byte]]("Hello!") + + result = true + + check: + waitFor(pipeTest()) == true + + test "pipe with `|` operator - A -> B": + proc pipeTest(): Future[bool] {.async.} = + var buf1 = newBufferStream() + var buf2 = buf1 | newBufferStream() + + var res1: seq[byte] = newSeq[byte](7) + var readFut = buf2.readExactly(addr res1[0], 7) + await buf1.write(cast[seq[byte]]("Hello1!")) + await readFut + + check: + res1 == cast[seq[byte]]("Hello1!") + + result = true + + check: + waitFor(pipeTest()) == true + + test "pipe with `|` operator - A -> B and B -> A": + proc pipeTest(): Future[bool] {.async.} = + var buf1 = newBufferStream() + var buf2 = newBufferStream() + + buf1 = buf1 | buf2 | buf1 + + var res1: seq[byte] = newSeq[byte](7) + var readFut1 = buf1.readExactly(addr res1[0], 7) + + var res2: seq[byte] = newSeq[byte](7) + var readFut2 = buf2.readExactly(addr res2[0], 7) + + await buf1.write(cast[seq[byte]]("Hello1!")) + await buf2.write(cast[seq[byte]]("Hello2!")) + await allFutures(readFut1, readFut2) + + check: + res1 == cast[seq[byte]]("Hello2!") + res2 == cast[seq[byte]]("Hello1!") + + result = true + + check: + waitFor(pipeTest()) == true + + test "pipe with `|` operator - A -> A (echo)": + proc pipeTest(): Future[bool] {.async.} = + var buf1 = newBufferStream() + + buf1 = buf1 | buf1 + + proc reader(): Future[seq[byte]] = buf1.read(6) + proc writer(): Future[void] = buf1.write(cast[seq[byte]]("Hello!")) + + var writerFut = writer() + var readerFut = reader() + + await writerFut + check: + (await readerFut) == cast[seq[byte]]("Hello!") + + result = true + + check: + waitFor(pipeTest()) == true + + # TODO: Need to implement deadlock prevention when + # piping to self + test "pipe deadlock": + proc pipeTest(): Future[bool] {.async.} = + + var buf1 = newBufferStream(size = 5) + + buf1 = buf1 | buf1 + + var count = 30000 + proc reader() {.async.} = + while count > 0: + discard await buf1.read(7) + + proc writer() {.async.} = + while count > 0: + await buf1.write(cast[seq[byte]]("Hello2!")) + count.dec + + var writerFut = writer() + var readerFut = reader() + + await allFutures(readerFut, writerFut) + result = true + + check: + waitFor(pipeTest()) == true diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 1bd44dd2e..09975f89d 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -274,25 +274,15 @@ suite "Mplex": expect LPStreamClosedError: waitFor(testClosedForWrite()) - test "half closed - channel should close for read": - proc testClosedForRead(): Future[void] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) - await chann.closedByRemote() - asyncDiscard chann.read() - - expect LPStreamClosedError: - waitFor(testClosedForRead()) - - test "half closed - channel should close for read after eof": + test "half closed - channel should close for read by remote": proc testClosedForRead(): Future[void] {.async.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.pushTo(cast[seq[byte]]("Hello!")) - await chann.close() - let msg = await chann.read() - asyncDiscard chann.read() + await chann.closedByRemote() + discard await chann.read() # this should work, since there is data in the buffer + discard await chann.read() # this should throw expect LPStreamClosedError: waitFor(testClosedForRead()) @@ -312,7 +302,7 @@ suite "Mplex": proc 
writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.reset() - asyncDiscard chann.read() + await chann.write(cast[seq[byte]]("Hello!")) expect LPStreamClosedError: waitFor(testResetWrite()) diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 7c3c566c2..1c59466c4 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -1,4 +1,4 @@ -import unittest, strutils, sequtils, sugar, strformat, options +import unittest, strutils, sequtils, strformat, options import chronos import ../libp2p/connection, ../libp2p/multistream, @@ -51,7 +51,8 @@ method write*(s: TestSelectStream, msg: seq[byte], msglen = -1) method write*(s: TestSelectStream, msg: string, msglen = -1) {.async, gcsafe.} = discard -method close(s: TestSelectStream) {.async, gcsafe.} = s.closed = true +method close(s: TestSelectStream) {.async, gcsafe.} = + s.isClosed = true proc newTestSelectStream(): TestSelectStream = new result @@ -97,7 +98,8 @@ method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} = method write*(s: TestLsStream, msg: string, msglen = -1) {.async, gcsafe.} = discard -method close(s: TestLsStream) {.async, gcsafe.} = s.closed = true +method close(s: TestLsStream) {.async, gcsafe.} = + s.isClosed = true proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} = new result @@ -143,7 +145,8 @@ method write*(s: TestNaStream, msg: string, msglen = -1) {.async, gcsafe.} = if s.step == 4: await s.na(msg) -method close(s: TestNaStream) {.async, gcsafe.} = s.closed = true +method close(s: TestNaStream) {.async, gcsafe.} = + s.isClosed = true proc newTestNaStream(na: NaHandler): TestNaStream = new result diff --git a/tests/testnative.nim b/tests/testnative.nim index 02e1b7a49..b5e3e598c 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -2,5 +2,11 @@ import unittest import testvarint, testbase32, testbase58, testbase64 import testrsa, testecnist, tested25519, testsecp256k1, testcrypto import testmultibase, testmultihash, testmultiaddress, testcid, testpeer -import testtransport, testmultistream, testbufferstream, - testmplex, testidentify, testswitch, testpubsub + +import testtransport, + testmultistream, + testbufferstream, + testidentify, + testswitch, + testpubsub, + testmplex diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 7e42eccee..5b71fe5fd 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,5 +1,5 @@ import unittest, tables, options -import chronos, chronicles +import chronos import ../libp2p/[switch, multistream, protocols/identify, @@ -36,7 +36,7 @@ method init(p: TestProto) {.gcsafe.} = suite "Switch": test "e2e use switch": - proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = + proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) {.gcsafe.}= let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo peerInfo.peerId = some(PeerID.init(seckey)) @@ -50,7 +50,11 @@ suite "Switch": let transports = @[Transport(newTransport(TcpTransport))] let muxers = [(MplexCodec, mplexProvider)].toTable() let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable() - let switch = newSwitch(peerInfo, transports, identify, muxers, secureManagers) + let switch = newSwitch(peerInfo, + transports, + identify, + muxers, + secureManagers) result = (switch, peerInfo) proc testSwitch(): Future[bool] {.async, gcsafe.} =