diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index b0cd263aa..0cb57ad2c 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -57,11 +57,11 @@ proc newStreamInternal*(m: Mplex, m.getChannelList(initiator)[id] = result method handle*(m: Mplex) {.async, gcsafe.} = + debug "starting mplex main loop" try: while not m.connection.closed: let msgRes = await m.connection.readMsg() if msgRes.isNone: - await sleepAsync(100.millis) continue let (id, msgType, data) = msgRes.get() diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim index 5b77405d7..aa466ea2e 100644 --- a/libp2p/muxers/muxer.nim +++ b/libp2p/muxers/muxer.nim @@ -62,8 +62,7 @@ method init(c: MuxerProvider) = proc handler(conn: Connection, proto: string) {.async, gcsafe, closure.} = let muxer = c.newMuxer(conn) if not isNil(c.muxerHandler): - debug "CALLING MUXER HANDLER" - await c.muxerHandler(muxer) + asyncCheck c.muxerHandler(muxer) if not isNil(c.streamHandler): muxer.streamHandler = c.streamHandler diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index bdd9955a5..92839a7f0 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -103,6 +103,7 @@ proc newIdentify*(peerInfo: PeerInfo): Identify = method init*(p: Identify) = proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} = + debug "handling identify request" var pb = encodeMsg(p.peerInfo, await conn.getObservedAddrs()) await conn.writeLp(pb.buffer) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 675dc11f9..f3707f0a3 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -90,6 +90,7 @@ proc identify*(s: Switch, conn: Connection) {.async, gcsafe.} = debug "identify: peer's public keys don't match ", msg = exc.msg proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = + debug "muxing connection" ## mux incoming connection let muxers = toSeq(s.muxers.keys) let muxerName = await s.ms.select(conn, muxers) @@ -122,7 +123,8 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = if conn.peerInfo.peerId.isSome: s.muxed[conn.peerInfo.peerId.get().pretty] = muxer -proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = +proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = + debug "handling connection", conn = conn result = conn ## perform upgrade flow if result.peerInfo.peerId.isSome: @@ -165,7 +167,7 @@ proc dial*(s: Switch, result = await t.dial(a) # make sure to assign the peer to the connection result.peerInfo = peer - result = await s.handleConn(result) + result = await s.upgradeOutgoing(result) let stream = await s.getMuxedStream(peer) if stream.isSome: @@ -188,11 +190,23 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = s.ms.addHandler(proto.codec, proto) -proc start*(s: Switch): Future[seq[Future[void]]] {.async.} = +proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = + let ms = newMultistream() + if (await ms.select(conn)): # just handshake + for secure in s.secureManagers: + ms.addHandler(secure.codec, secure) + + await ms.handle(conn) + + for muxer in s.muxers.values: + ms.addHandler(muxer.codec, muxer) + + await ms.handle(conn) + +proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: - if (await s.ms.select(conn)): # just handshake - await s.ms.handle(conn) # handle incoming connection + await s.upgradeIncoming(conn) # perform upgrade on incoming connection except: await s.cleanupConn(conn) @@ -262,17 +276,13 @@ proc newSwitch*(peerInfo: PeerInfo, let stream = await muxer.newStream() await s.identify(stream) - result.mount(val) - for s in secureManagers.deduplicate(): debug "adding secure manager ", codec = s.codec result.secureManagers.add(s) - result.mount(s) if result.secureManagers.len == 0: # use plain text if no secure managers are provided let manager = Secure(newPlainText()) - result.mount(manager) result.secureManagers.add(manager) result.secureManagers = result.secureManagers.deduplicate() diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 31c11c594..49d58abb6 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -132,7 +132,7 @@ suite "Mplex": await mplexListen.handle() let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) @@ -164,7 +164,7 @@ suite "Mplex": await mplexListen.handle() let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) @@ -199,7 +199,7 @@ suite "Mplex": debug "handle completed") let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) @@ -240,7 +240,7 @@ suite "Mplex": = debug "completed listener") let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 68c74a2d5..c6f139f26 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -1,4 +1,4 @@ -import unittest, strutils, sequtils, sugar, strformat +import unittest, strutils, sequtils, sugar, strformat, options import chronos import ../libp2p/connection, ../libp2p/multistream, @@ -165,7 +165,7 @@ suite "Multistream select": let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol proc testHandler(conn: Connection, proto: string): @@ -195,7 +195,7 @@ suite "Multistream select": let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol proc testHandler(conn: Connection, proto: string): @@ -222,7 +222,7 @@ suite "Multistream select": let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol proc testHandler(conn: Connection, proto: string): @@ -242,7 +242,7 @@ suite "Multistream select": let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol proc testHandler(conn: Connection, proto: string): @@ -259,7 +259,7 @@ suite "Multistream select": await msListen.handle(conn) let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() let transport2: TcpTransport = newTransport(TcpTransport) @@ -281,7 +281,7 @@ suite "Multistream select": let msListen = newMultistream() let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol protocol.handler = proc(conn: Connection, proto: string) {.async, gcsafe.} = await conn.close() @@ -295,7 +295,7 @@ suite "Multistream select": let transport1: TcpTransport = newTransport(TcpTransport) proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() let transport2: TcpTransport = newTransport(TcpTransport) @@ -315,7 +315,7 @@ suite "Multistream select": let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol proc testHandler(conn: Connection, proto: string): @@ -332,7 +332,7 @@ suite "Multistream select": await msListen.handle(conn) let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() let transport2: TcpTransport = newTransport(TcpTransport) @@ -354,7 +354,7 @@ suite "Multistream select": let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) var protocol: LPProtocol = new LPProtocol proc testHandler(conn: Connection, proto: string): @@ -371,7 +371,7 @@ suite "Multistream select": await msListen.handle(conn) let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() let transport2: TcpTransport = newTransport(TcpTransport) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 4eff82946..96681dea0 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,4 +1,4 @@ -import unittest, tables +import unittest, tables, options import chronos, chronicles import ../libp2p/switch, ../libp2p/multistream, @@ -34,7 +34,7 @@ suite "Switch": proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = let seckey = PrivateKey.random(RSA) var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) + peerInfo.peerId = some(PeerID.init(seckey)) peerInfo.addrs.add(ma) let identify = newIdentify(peerInfo) @@ -59,17 +59,17 @@ suite "Switch": testProto.init() testProto.codec = TestCodec switch1.mount(testProto) - await switch1.start() + asyncCheck switch1.start() (switch2, peerInfo2) = createSwitch(ma2) - await switch2.start() + asyncCheck switch2.start() let conn = await switch2.dial(peerInfo1, TestCodec) await conn.writeLp("Hello!") let msg = cast[string](await conn.readLp()) check "Hello!" == msg + # await allFutures(switch1.stop(), switch2.stop()) result = true - await allFutures(switch1.stop(), switch2.stop()) check: waitFor(testSwitch()) == true