From 92fa4110c1169b0040d36e86415ebe7ea3f4cc46 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Wed, 18 Nov 2020 20:06:42 -0600 Subject: [PATCH] Rework transport to use chronos accept (#420) * rework transport to use the new accept api * use the new chronos primits * fixup tests to use the new transport api * handle all exceptions in upgradeIncoming * master merge * add multiaddress exception type * raise appropriate exception on invalida address * allow retrying on TransportTooManyError * adding TODO * wip * merge master * add sleep if nil is returned * accept loop handles all exceptions * avoid issues with tray/except/finally * make consistent with master * cleanup accept loop * logging * Update libp2p/transports/tcptransport.nim Co-authored-by: Jacek Sieka * use Direction enum instead of initiator flag * use consistent import style * remove experimental `closeWithEOF()` Co-authored-by: Jacek Sieka --- libp2p/daemon/daemonapi.nim | 23 ++- libp2p/multiaddress.nim | 3 + libp2p/stream/connection.nim | 2 +- libp2p/switch.nim | 127 +++++++++------- libp2p/transports/tcptransport.nim | 212 ++++++++++++++------------ libp2p/transports/transport.nim | 48 ++++-- libp2p/wire.nim | 140 ++++++++++------- tests/testidentify.nim | 56 ++++--- tests/testmplex.nim | 236 ++++++++++++++++------------- tests/testmultistream.nim | 85 ++++++----- tests/testnoise.nim | 77 ++++++---- tests/testtransport.nim | 117 +++++++++----- 12 files changed, 669 insertions(+), 457 deletions(-) diff --git a/libp2p/daemon/daemonapi.nim b/libp2p/daemon/daemonapi.nim index 8784d93..2319a0a 100644 --- a/libp2p/daemon/daemonapi.nim +++ b/libp2p/daemon/daemonapi.nim @@ -157,6 +157,8 @@ type var daemonsCount {.threadvar.}: int +chronicles.formatIt(PeerInfo): shortLog(it) + proc requestIdentity(): ProtoBuffer = ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go ## Processing function `doIdentify(req *pb.Request)`. @@ -789,7 +791,7 @@ proc close*(api: DaemonAPI) {.async.} = pending.add(server.server.join()) await allFutures(pending) for server in api.servers: - let address = initTAddress(server.address) + let address = initTAddress(server.address).tryGet() discard tryRemoveFile($address) api.servers.setLen(0) # Closing daemon's process. @@ -800,7 +802,7 @@ proc close*(api: DaemonAPI) {.async.} = api.process.terminate() discard api.process.waitForExit() # Attempt to delete unix socket endpoint. - let address = initTAddress(api.address) + let address = initTAddress(api.address).tryGet() if address.family == AddressFamily.Unix: discard tryRemoveFile($address) @@ -1306,3 +1308,20 @@ proc pubsubSubscribe*(api: DaemonAPI, topic: string, except Exception as exc: await api.closeConnection(transp) raise exc + +proc shortLog*(pinfo: PeerInfo): string = + ## Get string representation of ``PeerInfo`` object. + result = newStringOfCap(128) + result.add("{PeerID: '") + result.add($pinfo.peer.shortLog()) + result.add("' Addresses: [") + let length = len(pinfo.addresses) + for i in 0.. 0: + result = result diff --git a/libp2p/multiaddress.nim b/libp2p/multiaddress.nim index 1edb3a2..f5a9873 100644 --- a/libp2p/multiaddress.nim +++ b/libp2p/multiaddress.nim @@ -46,6 +46,9 @@ type MaResult*[T] = Result[T, string] + MaError* = object of CatchableError + MaInvalidAddress* = object of MaError + IpTransportProtocol* = enum tcpProtocol udpProtocol diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index c9a6314..2b8c5e9 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/[hashes, oids, strformat, sugar] +import std/[hashes, oids, strformat] import chronicles, chronos, metrics import lpstream, ../multiaddress, diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 8a6380d..398fd78 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -60,6 +60,7 @@ type streamHandler*: StreamHandler secureManagers*: seq[Secure] dialLock: Table[PeerID, AsyncLock] + acceptFuts: seq[Future[void]] proc addConnEventHandler*(s: Switch, handler: ConnEventHandler, @@ -211,47 +212,50 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g return sconn -proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = - trace "Upgrading incoming connection", conn +proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = # noraises + trace "Upgrading incoming connection", incomingConn let ms = newMultistream() # secure incoming connections - proc securedHandler (conn: Connection, - proto: string) - {.async, gcsafe, closure.} = + proc securedHandler(conn: Connection, + proto: string) + {.async, gcsafe, closure.} = trace "Starting secure handler", conn let secure = s.secureManagers.filterIt(it.codec == proto)[0] + var sconn: Connection try: - var sconn = await secure.secure(conn, false) + sconn = await secure.secure(conn, false) if isNil(sconn): return - defer: - await sconn.close() - # add the muxer for muxer in s.muxers.values: ms.addHandler(muxer.codecs, muxer) # handle subsequent secure requests await ms.handle(sconn) - - except CancelledError as exc: - raise exc except CatchableError as exc: - debug "Exception in secure handler", msg = exc.msg, conn + debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn + finally: + if not isNil(sconn): + await sconn.close() trace "Stopped secure handler", conn - if (await ms.select(conn)): # just handshake - # add the secure handlers - for k in s.secureManagers: - ms.addHandler(k.codec, securedHandler) + try: + if (await ms.select(incomingConn)): # just handshake + # add the secure handlers + for k in s.secureManagers: + ms.addHandler(k.codec, securedHandler) - # handle un-secured connections - # we handshaked above, set this ms handler as active - await ms.handle(conn, active = true) + # handle un-secured connections + # we handshaked above, set this ms handler as active + await ms.handle(incomingConn, active = true) + except CatchableError as exc: + debug "Exception upgrading incoming", exc = exc.msg + finally: + await incomingConn.close() proc internalConnect(s: Switch, peerId: PeerID, @@ -280,7 +284,7 @@ proc internalConnect(s: Switch, return conn - trace "Dialing peer", peerId + debug "Dialing peer", peerId for t in s.transports: # for each transport for a in addrs: # for each address if t.handles(a): # check if it can dial it @@ -288,10 +292,10 @@ proc internalConnect(s: Switch, let dialed = try: await t.dial(a) except CancelledError as exc: - trace "Dialing canceled", msg = exc.msg, peerId + debug "Dialing canceled", msg = exc.msg, peerId raise exc except CatchableError as exc: - trace "Dialing failed", msg = exc.msg, peerId + debug "Dialing failed", msg = exc.msg, peerId libp2p_failed_dials.inc() continue # Try the next address @@ -314,7 +318,7 @@ proc internalConnect(s: Switch, doAssert not isNil(upgraded), "connection died after upgradeOutgoing" conn = upgraded - trace "Dial successful", conn, peerInfo = conn.peerInfo + debug "Dial successful", conn, peerInfo = conn.peerInfo break finally: if lock.locked(): @@ -407,41 +411,63 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) {.gcsafe s.ms.addHandler(proto.codecs, proto, matcher) +proc accept(s: Switch, transport: Transport) {.async.} = # noraises + ## transport's accept loop + ## + + while transport.running: + var conn: Connection + try: + debug "About to accept incoming connection" + conn = await transport.accept() + if not isNil(conn): + debug "Accepted an incoming connection", conn + asyncSpawn s.upgradeIncoming(conn) # perform upgrade on incoming connection + else: + # A nil connection means that we might have hit a + # file-handle limit (or another non-fatal error), + # we can get one on the next try, but we should + # be careful to not end up in a thigh loop that + # will starve the main event loop, thus we sleep + # here before retrying. + await sleepAsync(100.millis) # TODO: should be configurable? + except CatchableError as exc: + debug "Exception in accept loop, exiting", exc = exc.msg + if not isNil(conn): + await conn.close() + + return + proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = trace "starting switch for peer", peerInfo = s.peerInfo - - proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = - trace "Incoming connection", conn - try: - await s.upgradeIncoming(conn) # perform upgrade on incoming connection - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "Exception occurred in incoming handler", conn, msg = exc.msg - finally: - await conn.close() - trace "Connection handler done", conn - var startFuts: seq[Future[void]] for t in s.transports: # for each transport for i, a in s.peerInfo.addrs: if t.handles(a): # check if it handles the multiaddr - var server = await t.listen(a, handle) + var server = t.start(a) s.peerInfo.addrs[i] = t.ma # update peer's address + s.acceptFuts.add(s.accept(t)) startFuts.add(server) debug "Started libp2p node", peer = s.peerInfo - result = startFuts # listen for incoming connections + return startFuts # listen for incoming connections proc stop*(s: Switch) {.async.} = trace "Stopping switch" + for a in s.acceptFuts: + if not a.finished: + a.cancel() + + checkFutures( + await allFinished(s.acceptFuts)) + # close and cleanup all connections await s.connManager.close() for t in s.transports: try: - await t.close() + await t.stop() except CancelledError as exc: raise exc except CatchableError as exc: @@ -465,17 +491,16 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = s.connManager.storeMuxer(muxer) try: - try: - await s.identify(muxer) - except IdentifyError as exc: - # Identify is non-essential, though if it fails, it might indicate that - # the connection was closed already - this will be picked up by the read - # loop - debug "Could not identify connection", conn, msg = exc.msg - except LPStreamClosedError as exc: - debug "Identify stream closed", conn, msg = exc.msg - except LPStreamEOFError as exc: - debug "Identify stream EOF", conn, msg = exc.msg + await s.identify(muxer) + except IdentifyError as exc: + # Identify is non-essential, though if it fails, it might indicate that + # the connection was closed already - this will be picked up by the read + # loop + debug "Could not identify connection", conn, msg = exc.msg + except LPStreamClosedError as exc: + debug "Identify stream closed", conn, msg = exc.msg + except LPStreamEOFError as exc: + debug "Identify stream EOF", conn, msg = exc.msg except CancelledError as exc: await muxer.close() raise exc diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 7783ced..b03dccf 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -7,8 +7,8 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import oids -import chronos, chronicles, sequtils +import std/[oids, sequtils] +import chronos, chronicles import transport, ../errors, ../wire, @@ -26,10 +26,8 @@ const type TcpTransport* = ref object of Transport server*: StreamServer - clients: seq[StreamTransport] + clients: array[Direction, seq[StreamTransport]] flags: set[ServerFlags] - cleanups*: seq[Future[void]] - handlers*: seq[Future[void]] TcpTransportTracker* = ref object of TrackerBase opened*: uint64 @@ -61,132 +59,144 @@ proc setupTcpTransportTracker(): TcpTransportTracker = proc connHandler*(t: TcpTransport, client: StreamTransport, - initiator: bool): Connection = - trace "handling connection", address = $client.remoteAddress + dir: Direction): Future[Connection] {.async.} = + debug "Handling tcp connection", address = $client.remoteAddress, + dir = $dir, + clients = t.clients[Direction.In].len + + t.clients[Direction.Out].len let conn = Connection( ChronosStream.init( client, - dir = if initiator: - Direction.Out - else: - Direction.In)) + dir + )) - if not initiator: - if not isNil(t.handler): - t.handlers &= t.handler(conn) - - proc cleanup() {.async.} = + proc onClose() {.async.} = try: - await client.join() - trace "cleaning up client", addrs = $client.remoteAddress, connoid = $conn.oid - if not(isNil(conn)): - await conn.close() - t.clients.keepItIf(it != client) - except CancelledError: - # This is top-level procedure which will work as separate task, so it - # do not need to propogate CancelledError. - trace "Unexpected cancellation in transport's cleanup" - except CatchableError as exc: - trace "error cleaning up client", exc = exc.msg + await client.join() or conn.join() + trace "Cleaning up client", addrs = $client.remoteAddress, + conn - t.clients.add(client) - # All the errors are handled inside `cleanup()` procedure. - asyncSpawn cleanup() + if not(isNil(conn) and conn.closed()): + await conn.close() + + t.clients[dir].keepItIf( it != client ) + if not(isNil(client) and client.closed()): + await client.closeWait() + + trace "Cleaned up client", addrs = $client.remoteAddress, + conn + + except CatchableError as exc: + let useExc {.used.} = exc + debug "Error cleaning up client", errMsg = exc.msg, conn + + t.clients[dir].add(client) + asyncSpawn onClose() try: conn.observedAddr = MultiAddress.init(client.remoteAddress).tryGet() except CatchableError as exc: - trace "Connection setup failed", exc = exc.msg - if not(isNil(client)): - client.close() + trace "Connection setup failed", exc = exc.msg, conn + if not(isNil(client) and client.closed): + await client.closeWait() + + raise exc return conn -proc connCb(server: StreamServer, - client: StreamTransport) {.async, gcsafe.} = - trace "incoming connection", address = $client.remoteAddress - try: - let t = cast[TcpTransport](server.udata) - # we don't need result connection in this case - # as it's added inside connHandler - discard t.connHandler(client, false) - except CancelledError as exc: - raise exc - except CatchableError as err: - debug "Connection setup failed", err = err.msg - client.close() - -proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T = +proc init*(T: type TcpTransport, + flags: set[ServerFlags] = {}): T = result = T(flags: flags) + result.initTransport() method initTransport*(t: TcpTransport) = t.multicodec = multiCodec("tcp") inc getTcpTransportTracker().opened -method close*(t: TcpTransport) {.async, gcsafe.} = - try: - ## start the transport - trace "stopping transport" - await procCall Transport(t).close() # call base - - checkFutures(await allFinished( - t.clients.mapIt(it.closeWait()))) - - # server can be nil - if not isNil(t.server): - t.server.stop() - await t.server.closeWait() - - t.server = nil - - for fut in t.handlers: - if not fut.finished: - fut.cancel() - - checkFutures( - await allFinished(t.handlers)) - t.handlers = @[] - - for fut in t.cleanups: - if not fut.finished: - fut.cancel() - - checkFutures( - await allFinished(t.cleanups)) - t.cleanups = @[] - - trace "transport stopped" - inc getTcpTransportTracker().closed - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "error shutting down tcp transport", exc = exc.msg - -method listen*(t: TcpTransport, - ma: MultiAddress, - handler: ConnHandler): - Future[Future[void]] {.async, gcsafe.} = - discard await procCall Transport(t).listen(ma, handler) # call base - +method start*(t: TcpTransport, ma: MultiAddress) {.async.} = ## listen on the transport - t.server = createStreamServer(t.ma, connCb, t.flags, t) - t.server.start() + ## + + if t.running: + trace "TCP transport already running" + return + + await procCall Transport(t).start(ma) + trace "Starting TCP transport" + + t.server = createStreamServer(t.ma, t.flags, t) # always get the resolved address in case we're bound to 0.0.0.0:0 t.ma = MultiAddress.init(t.server.sock.getLocalAddress()).tryGet() - result = t.server.join() - trace "started node on", address = t.ma + t.running = true + + trace "Listening on", address = t.ma + +method stop*(t: TcpTransport) {.async, gcsafe.} = + ## stop the transport + ## + + try: + trace "Stopping TCP transport" + await procCall Transport(t).stop() # call base + + checkFutures( + await allFinished( + t.clients[Direction.In].mapIt(it.closeWait()) & + t.clients[Direction.Out].mapIt(it.closeWait()))) + + # server can be nil + if not isNil(t.server): + await t.server.closeWait() + + t.server = nil + trace "Transport stopped" + inc getTcpTransportTracker().closed + except CatchableError as exc: + trace "Error shutting down tcp transport", exc = exc.msg + finally: + t.running = false + +template withTransportErrors(body: untyped): untyped = + try: + body + except TransportTooManyError as exc: + warn "Too many files opened", exc = exc.msg + except TransportUseClosedError as exc: + info "Server was closed", exc = exc.msg + raise newTransportClosedError(exc) + except CatchableError as exc: + trace "Unexpected error creating connection", exc = exc.msg + raise exc + +method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} = + ## accept a new TCP connection + ## + + if not t.running: + raise newTransportClosedError() + + withTransportErrors: + let transp = await t.server.accept() + return await t.connHandler(transp, Direction.In) method dial*(t: TcpTransport, address: MultiAddress): Future[Connection] {.async, gcsafe.} = - trace "dialing remote peer", address = $address ## dial a peer - let client: StreamTransport = await connect(address) - result = t.connHandler(client, true) + ## + + trace "Dialing remote peer", address = $address + + withTransportErrors: + let transp = await connect(address) + return await t.connHandler(transp, Direction.Out) method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = if procCall Transport(t).handles(address): - result = address.protocols.tryGet().filterIt( it == multiCodec("tcp") ).len > 0 + return address.protocols + .tryGet() + .filterIt( it == multiCodec("tcp") ) + .len > 0 diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index e34df33..a3ac045 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -15,44 +15,62 @@ import ../stream/connection, ../multicodec type - ConnHandler* = proc (conn: Connection): Future[void] {.gcsafe.} + TransportClosedError* = object of CatchableError Transport* = ref object of RootObj ma*: Multiaddress - handler*: ConnHandler multicodec*: MultiCodec + running*: bool + +proc newTransportClosedError*(parent: ref Exception = nil): ref CatchableError = + newException(TransportClosedError, + "Transport closed, no more connections!", parent) method initTransport*(t: Transport) {.base, gcsafe, locks: "unknown".} = ## perform protocol initialization + ## + discard -method close*(t: Transport) {.base, async, gcsafe.} = +method start*(t: Transport, ma: MultiAddress) {.base, async.} = + ## start the transport + ## + + t.ma = ma + trace "starting transport", address = $ma + +method stop*(t: Transport) {.base, async.} = ## stop and cleanup the transport ## including all outstanding connections + ## + discard -method listen*(t: Transport, - ma: MultiAddress, - handler: ConnHandler): - Future[Future[void]] {.base, async, gcsafe.} = - ## listen for incoming connections - t.ma = ma - t.handler = handler - trace "starting node", address = $ma +method accept*(t: Transport): Future[Connection] + {.base, async, gcsafe.} = + ## accept incoming connections + ## + + discard method dial*(t: Transport, - address: MultiAddress): - Future[Connection] {.base, async, gcsafe.} = + address: MultiAddress): Future[Connection] + {.base, async, gcsafe.} = ## dial a peer + ## + discard method upgrade*(t: Transport) {.base, async, gcsafe.} = ## base upgrade method that the transport uses to perform ## transport specific upgrades + ## + discard method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} = - ## check if transport supportes the multiaddress + ## check if transport supports the multiaddress + ## # by default we skip circuit addresses to avoid # having to repeat the check in every transport @@ -60,4 +78,6 @@ method handles*(t: Transport, address: MultiAddress): bool {.base, gcsafe.} = method localAddress*(t: Transport): MultiAddress {.base, gcsafe.} = ## get the local address of the transport in case started with 0.0.0.0:0 + ## + discard diff --git a/libp2p/wire.nim b/libp2p/wire.nim index 870be29..5fb4e55 100644 --- a/libp2p/wire.nim +++ b/libp2p/wire.nim @@ -8,7 +8,7 @@ ## those terms. ## This module implements wire network connection procedures. -import chronos +import chronos, stew/endians2 import multiaddress, multicodec when defined(windows): @@ -16,57 +16,65 @@ when defined(windows): else: import posix -proc initTAddress*(ma: MultiAddress): TransportAddress = +const + TRANSPMA* = mapOr( + mapAnd(IP, mapEq("udp")), + mapAnd(IP, mapEq("tcp")), + mapAnd(mapEq("unix")) + ) + + RTRANSPMA* = mapOr( + mapAnd(IP, mapEq("tcp")), + mapAnd(mapEq("unix")) + ) + +proc initTAddress*(ma: MultiAddress): MaResult[TransportAddress] {. + raises: [Defect, ResultError[string]] .} = ## Initialize ``TransportAddress`` with MultiAddress ``ma``. ## ## MultiAddress must be wire address, e.g. ``{IP4, IP6, UNIX}/{TCP, UDP}``. - var state = 0 - var pbuf: array[2, byte] - for rpart in ma.items(): - let - part = rpart.tryGet() - rcode = part.protoCode() - code = rcode.tryGet() - - if state == 0: - if code == multiCodec("ip4"): - result = TransportAddress(family: AddressFamily.IPv4) - if part.protoArgument(result.address_v4).tryGet() == 0: - raise newException(TransportAddressError, "Incorrect IPv4 address") - inc(state) - elif code == multiCodec("ip6"): - result = TransportAddress(family: AddressFamily.IPv6) - if part.protoArgument(result.address_v6).tryGet() == 0: - raise newException(TransportAddressError, "Incorrect IPv6 address") - inc(state) - elif code == multiCodec("unix"): - result = TransportAddress(family: AddressFamily.Unix) - if part.protoArgument(result.address_un).tryGet() == 0: - raise newException(TransportAddressError, "Incorrect Unix address") - result.port = Port(1) - break + if TRANSPMA.match(ma): + var pbuf: array[2, byte] + let code = ma[0].tryGet().protoCode().tryGet() + if code == multiCodec("unix"): + var res = TransportAddress(family: AddressFamily.Unix) + if ma[0].tryGet().protoArgument(res.address_un).tryGet() == 0: + err("Incorrect Unix domain address") else: - raise newException(TransportAddressError, "Could not initialize address!") - elif state == 1: - if code == multiCodec("tcp") or code == multiCodec("udp"): - if part.protoArgument(pbuf).tryGet() == 0: - raise newException(TransportAddressError, "Incorrect port") - result.port = Port((cast[uint16](pbuf[0]) shl 8) or - cast[uint16](pbuf[1])) - break + res.port = Port(1) + ok(res) + elif code == multiCodec("ip4"): + var res = TransportAddress(family: AddressFamily.IPv4) + if ma[0].tryGet().protoArgument(res.address_v4).tryGet() == 0: + err("Incorrect IPv4 address") else: - raise newException(TransportAddressError, "Could not initialize address!") + if ma[1].tryGet().protoArgument(pbuf).tryGet() == 0: + err("Incorrect port number") + else: + res.port = Port(fromBytesBE(uint16, pbuf)) + ok(res) + else: + var res = TransportAddress(family: AddressFamily.IPv6) + if ma[0].tryGet().protoArgument(res.address_v6).tryGet() == 0: + err("Incorrect IPv6 address") + else: + if ma[1].tryGet().protoArgument(pbuf).tryGet() == 0: + err("Incorrect port number") + else: + res.port = Port(fromBytesBE(uint16, pbuf)) + ok(res) + else: + err("MultiAddress must be wire address (tcp, udp or unix)") proc connect*(ma: MultiAddress, bufferSize = DefaultStreamBufferSize, child: StreamTransport = nil): Future[StreamTransport] {.async.} = ## Open new connection to remote peer with address ``ma`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` is size of internal buffer for transport. + if not(RTRANSPMA.match(ma)): + raise newException(MaInvalidAddress, "Incorrect or unsupported address!") - let address = initTAddress(ma) - if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: - if ma[1].tryGet().protoCode().tryGet() != multiCodec("tcp"): - raise newException(TransportAddressError, "Incorrect address type!") + let address = initTAddress(ma).tryGet() result = await connect(address, bufferSize, child) proc createStreamServer*[T](ma: MultiAddress, @@ -79,11 +87,27 @@ proc createStreamServer*[T](ma: MultiAddress, child: StreamServer = nil, init: TransportInitCallback = nil): StreamServer = ## Create new TCP stream server which bounds to ``ma`` address. - var address = initTAddress(ma) - if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: - if ma[1].tryGet().protoCode().tryGet() != multiCodec("tcp"): - raise newException(TransportAddressError, "Incorrect address type!") - result = createStreamServer(address, cbproc, flags, udata, sock, backlog, + if not(RTRANSPMA.match(ma)): + raise newException(MaInvalidAddress, "Incorrect or unsupported address!") + + let address = initTAddress(ma) + result = createStreamServer(address.tryGet(), cbproc, flags, udata, sock, + backlog, bufferSize, child, init) + +proc createStreamServer*[T](ma: MultiAddress, + flags: set[ServerFlags] = {}, + udata: ref T, + sock: AsyncFD = asyncInvalidSocket, + backlog: int = 100, + bufferSize: int = DefaultStreamBufferSize, + child: StreamServer = nil, + init: TransportInitCallback = nil): StreamServer = + ## Create new TCP stream server which bounds to ``ma`` address. + if not(RTRANSPMA.match(ma)): + raise newException(MaInvalidAddress, "Incorrect or unsupported address!") + + let address = initTAddress(ma) + result = createStreamServer(address.tryGet(), flags, udata, sock, backlog, bufferSize, child, init) proc createAsyncSocket*(ma: MultiAddress): AsyncFD = @@ -91,16 +115,18 @@ proc createAsyncSocket*(ma: MultiAddress): AsyncFD = ## protocol information. ## ## Returns ``asyncInvalidSocket`` on error. + ## + ## Note: This procedure only used in `go-libp2p-daemon` wrapper. var socktype: SockType = SockType.SOCK_STREAM protocol: Protocol = Protocol.IPPROTO_TCP - address: TransportAddress - try: - address = initTAddress(ma) - except: + let maddr = initTAddress(ma) + if maddr.isErr(): return asyncInvalidSocket + let address = maddr.tryGet() + if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if ma[1].tryGet().protoCode().tryGet() == multiCodec("udp"): socktype = SockType.SOCK_DGRAM @@ -117,22 +143,28 @@ proc createAsyncSocket*(ma: MultiAddress): AsyncFD = proc bindAsyncSocket*(sock: AsyncFD, ma: MultiAddress): bool = ## Bind socket ``sock`` to MultiAddress ``ma``. + ## + ## Note: This procedure only used in `go-libp2p-daemon` wrapper. var saddr: Sockaddr_storage slen: SockLen - address: TransportAddress - try: - address = initTAddress(ma) - except: + + let maddr = initTAddress(ma) + if maddr.isErr(): return false + + let address = maddr.tryGet() toSAddr(address, saddr, slen) - if bindSocket(SocketHandle(sock), cast[ptr SockAddr](addr saddr), slen) == 0: + if bindSocket(SocketHandle(sock), cast[ptr SockAddr](addr saddr), + slen) == 0: result = true else: result = false proc getLocalAddress*(sock: AsyncFD): TransportAddress = ## Retrieve local socket ``sock`` address. + ## + ## Note: This procedure only used in `go-libp2p-daemon` wrapper. var saddr: Sockaddr_storage var slen = SockLen(sizeof(Sockaddr_storage)) diff --git a/tests/testidentify.nim b/tests/testidentify.nim index 89e8736..f57355c 100644 --- a/tests/testidentify.nim +++ b/tests/testidentify.nim @@ -20,21 +20,22 @@ suite "Identify": asyncTest "handle identify message": let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() let remoteSecKey = PrivateKey.random(ECDSA, rng[]).get() - let remotePeerInfo = PeerInfo.init(remoteSecKey, - [ma], - ["/test/proto1/1.0.0", - "/test/proto2/1.0.0"]) + let remotePeerInfo = PeerInfo.init( + remoteSecKey, [ma], ["/test/proto1/1.0.0", "/test/proto2/1.0.0"]) var serverFut: Future[void] let identifyProto1 = newIdentify(remotePeerInfo) let msListen = newMultistream() msListen.addHandler(IdentifyCodec, identifyProto1) - proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = - await msListen.handle(conn) var transport1 = TcpTransport.init() - serverFut = await transport1.listen(ma, connHandler) + serverFut = transport1.start(ma) + proc acceptHandler(): Future[void] {.async, gcsafe.} = + let conn = await transport1.accept() + await msListen.handle(conn) + + let acceptFut = acceptHandler() let msDial = newMultistream() let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) @@ -51,9 +52,10 @@ suite "Identify": check id.protos == @["/test/proto1/1.0.0", "/test/proto2/1.0.0"] await conn.close() - await transport1.close() + await acceptFut + await transport1.stop() await serverFut - await transport2.close() + await transport2.stop() asyncTest "handle failed identify": let ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() @@ -61,17 +63,22 @@ suite "Identify": let identifyProto1 = newIdentify(remotePeerInfo) let msListen = newMultistream() - let done = newFuture[void]() - msListen.addHandler(IdentifyCodec, identifyProto1) - proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = - await msListen.handle(conn) - await conn.close() - done.complete() let transport1: TcpTransport = TcpTransport.init() - asyncCheck transport1.listen(ma, connHandler) + asyncCheck transport1.start(ma) + proc acceptHandler() {.async.} = + var conn: Connection + try: + conn = await transport1.accept() + await msListen.handle(conn) + except CatchableError: + discard + finally: + await conn.close() + + let acceptFut = acceptHandler() let msDial = newMultistream() let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) @@ -79,12 +86,11 @@ suite "Identify": let identifyProto2 = newIdentify(localPeerInfo) expect IdentityNoMatchError: - try: - let pi2 = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) - discard await msDial.select(conn, IdentifyCodec) - discard await identifyProto2.identify(conn, pi2) - finally: - await done.wait(5000.millis) # when no issues will not wait that long! - await conn.close() - await transport2.close() - await transport1.close() + let pi2 = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) + discard await msDial.select(conn, IdentifyCodec) + discard await identifyProto2.identify(conn, pi2) + + await conn.close() + await acceptFut.wait(5000.millis) # when no issues will not wait that long! + await transport2.stop() + await transport1.stop() diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 9e6dc6f..0da8761 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -326,22 +326,22 @@ suite "Mplex": asyncTest "read/write receiver": let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() - var done = newFuture[void]() - proc connHandler(conn: Connection) {.async, gcsafe.} = + let transport1: TcpTransport = TcpTransport.init() + let listenFut = transport1.start(ma) + + proc acceptHandler() {.async, gcsafe.} = + let conn = await transport1.accept() let mplexListen = Mplex.init(conn) mplexListen.streamHandler = proc(stream: Connection) {.async, gcsafe.} = let msg = await stream.readLp(1024) check string.fromBytes(msg) == "HELLO" await stream.close() - done.complete() await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = TcpTransport.init() - let listenFut = await transport1.listen(ma, connHandler) - + let acceptFut = acceptHandler() let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) @@ -352,34 +352,33 @@ suite "Mplex": check LPChannel(stream).isOpen # not lazy await stream.close() - await done.wait(1.seconds) await conn.close() + await acceptFut.wait(1.seconds) await mplexDialFut.wait(1.seconds) await allFuturesThrowing( - transport1.close(), - transport2.close()) - + transport1.stop(), + transport2.stop()) await listenFut asyncTest "read/write receiver lazy": let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() - var done = newFuture[void]() - proc connHandler(conn: Connection) {.async, gcsafe.} = + let transport1: TcpTransport = TcpTransport.init() + let listenFut = transport1.start(ma) + + proc acceptHandler() {.async, gcsafe.} = + let conn = await transport1.accept() let mplexListen = Mplex.init(conn) mplexListen.streamHandler = proc(stream: Connection) {.async, gcsafe.} = let msg = await stream.readLp(1024) check string.fromBytes(msg) == "HELLO" await stream.close() - done.complete() await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = TcpTransport.init() - let listenFut = await transport1.listen(ma, connHandler) - + let acceptFut = acceptHandler() let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) @@ -391,12 +390,12 @@ suite "Mplex": check LPChannel(stream).isOpen # assert lazy await stream.close() - await done.wait(1.seconds) await conn.close() + await acceptFut.wait(1.seconds) await mplexDialFut await allFuturesThrowing( - transport1.close(), - transport2.close()) + transport1.stop(), + transport2.stop()) await listenFut asyncTest "write fragmented": @@ -408,8 +407,12 @@ suite "Mplex": for _ in 0..