diff --git a/libp2p/switch.nim b/libp2p/switch.nim index fd6858c20..d6dddf899 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -9,11 +9,16 @@ import tables, sequtils, options, strformat import chronos -import connection, transport, +import connection, + transports/transport, stream/lpstream, - multistream, protocol, - peerinfo, multiaddress, - identify, muxers/muxer, + multistream, + protocols/protocol, + protocols/secure, + peerinfo, + multiaddress, + protocols/identify, + muxers/muxer, peer type @@ -26,63 +31,55 @@ type muxers*: Table[string, MuxerProvider] ms*: MultisteamSelect identity*: Identify + streamHandler*: StreamHandler + secureManager*: Secure -proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} - -proc newSwitch*(peerInfo: PeerInfo, - transports: seq[Transport], - identity: Identify, - muxers: Table[string, MuxerProvider]): Switch = - new result - result.peerInfo = peerInfo - result.ms = newMultistream() - result.transports = transports - result.connections = newTable[string, Connection]() - result.muxed = newTable[string, Muxer]() - result.identity = identity - result.muxers = muxers - - result.ms.addHandler(IdentifyCodec, identity) - -proc secure(s: Switch, conn: Connection) {.async, gcsafe.} = +proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = ## secure the incoming connection - discard + + # plaintext for now, doesn't do anything + if not (await s.ms.select(conn, s.secureManager.codec)): + raise newException(CatchableError, "Unable to negotiate a secure channel!") + + result = conn proc identify(s: Switch, conn: Connection) {.async, gcsafe.} = ## identify the connection - s.peerInfo.protocols = await s.ms.list(conn) # update protos before engagin in identify - let info = await s.identity.identify(conn, conn.peerInfo) + # s.peerInfo.protocols = await s.ms.list(conn) # update protos before engagin in identify + try: + if (await s.ms.select(conn, s.identity.codec)): + let info = await s.identity.identify(conn, conn.peerInfo) - let id = if conn.peerInfo.isSome: conn.peerInfo.get().peerId.pretty else: "" - if s.connections.contains(id): - let connection = s.connections[id] - var peerInfo = conn.peerInfo.get() - peerInfo.peerId = PeerID.init(info.pubKey) # we might not have a peerId at all - peerInfo.addrs = info.addrs - peerInfo.protocols = info.protos + let id = if conn.peerInfo.isSome: + conn.peerInfo.get().peerId.pretty + else: + "" + if id.len > 0 and s.connections.contains(id): + let connection = s.connections[id] + var peerInfo = conn.peerInfo.get() + peerInfo.peerId = PeerID.init(info.pubKey) # we might not have a peerId at all + peerInfo.addrs = info.addrs + peerInfo.protocols = info.protos + except IdentityInvalidMsgError as exc: + echo exc.msg # TODO: Loging + except IdentityNoMatchError as exc: + echo exc.msg # TODO: Loging -proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = +proc mux(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = ## mux incoming connection let muxers = toSeq(s.muxers.keys) let muxerName = await s.ms.select(conn, muxers) - if muxerName.len == 0: + if muxerName.len == 0 or muxerName == "na": return let muxer = s.muxers[muxerName].newMuxer(conn) # install stream handler - muxer.streamHandler = proc (stream: Connection) {.async, gcsafe.} = - try: - # TODO: figure out proper way of handling this. - # Perhaps it's ok to discard this Future and handle - # errors elsewere? - asyncCheck s.ms.handle(stream) # handle incoming connection - finally: - await stream.close() + muxer.streamHandler = s.streamHandler # do identify first, so that we have a # PeerInfo in case we didn't before - let stream = await muxer.newStream() - await s.identify(stream) + result = await muxer.newStream() + # await s.identify(result) # store it in muxed connections if we have a peer for it # TODO: We should make sure that this are cleaned up properly @@ -93,20 +90,20 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = s.muxed[conn.peerInfo.get().peerId.pretty] = muxer proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = + result = conn ## perform upgrade flow - if conn.peerInfo.isSome: - let id = conn.peerInfo.get().peerId.pretty + if result.peerInfo.isSome: + let id = result.peerInfo.get().peerId.pretty if s.connections.contains(id): # if we already have a connection for this peer, # close the incoming connection and return the # existing one - await conn.close() + await result.close() return s.connections[id] - s.connections[id] = conn + s.connections[id] = result - await s.secure(conn) # secure the connection - # await s.mux(conn) # mux it if possible - result = conn + result = await s.secure(conn) # secure the connection + result = await s.mux(result) # mux it if possible proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = let id = if conn.peerInfo.isSome: conn.peerInfo.get().peerId.pretty else: "" @@ -124,21 +121,21 @@ proc dial*(s: Switch, for t in s.transports: # for each transport for a in peer.addrs: # for each address if t.handles(a): # check if it can dial it - var conn = await t.dial(a) - conn = await s.handleConn(conn) + result = await t.dial(a) + result.peerInfo = some(peer) + result = await s.handleConn(result) if s.muxed.contains(peer.peerId.pretty): - conn = await s.muxed[peer.peerId.pretty].newStream() - if (await s.ms.select(conn, proto)).len == 0: + result = await s.muxed[peer.peerId.pretty].newStream() + if (await s.ms.select(result, proto)): raise newException(CatchableError, &"Unable to select protocol: {proto}") - result = conn -proc mount*[T: LPProtocol](s: Switch, proto: T) = +proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = if isNil(proto.handler): raise newException(CatchableError, "Protocol has to define a handle method or proc") - if len(proto.codec) <= 0: + if proto.codec.len == 0: raise newException(CatchableError, "Protocol has to define a codec string") @@ -147,10 +144,8 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) = proc start*(s: Switch) {.async.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: - # TODO: figure out proper way of handling this. - # Perhaps it's ok to discard this Future and handle - # errors elsewere? - asyncCheck s.ms.handle(conn) # handle incoming connection + if (await s.ms.select(conn)): + await s.ms.handle(conn) # handle incoming connection except: await s.cleanupConn(conn) @@ -161,3 +156,31 @@ proc start*(s: Switch) {.async.} = proc stop*(s: Switch) {.async.} = await allFutures(s.transports.mapIt(it.close())) + +proc newSwitch*(peerInfo: PeerInfo, + transports: seq[Transport], + identity: Identify, + muxers: Table[string, MuxerProvider]): Switch = + new result + result.peerInfo = peerInfo + result.ms = newMultistream() + result.transports = transports + result.connections = newTable[string, Connection]() + result.muxed = newTable[string, Muxer]() + result.identity = identity + result.muxers = muxers + + let s = result # can't capture result + result.streamHandler = proc(stream: Connection) {.async, gcsafe.} = + # TODO: figure out proper way of handling this. + # Perhaps it's ok to discard this Future and handle + # errors elsewere? + await s.ms.handle(stream) # handle incoming connection + + result.mount(identity) + for key, val in muxers: + val.streamHandler = result.streamHandler + result.mount(val) + + result.secureManager = Secure(newPlainText()) + result.mount(result.secureManager) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index cee417221..29ba714fb 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,11 +1,16 @@ import unittest, tables import chronos -import ../libp2p/switch, ../libp2p/multistream, - ../libp2p/identify, ../libp2p/connection, - ../libp2p/transport, ../libp2p/tcptransport, - ../libp2p/multiaddress, ../libp2p/peerinfo, - ../libp2p/crypto/crypto, ../libp2p/peer, - ../libp2p/protocol, ../libp2p/muxers/muxer, +import ../libp2p/switch, + ../libp2p/multistream, + ../libp2p/protocols/identify, + ../libp2p/connection, + ../libp2p/transports/[transport, tcptransport], + ../libp2p/multiaddress, + ../libp2p/peerinfo, + ../libp2p/crypto/crypto, + ../libp2p/peer, + ../libp2p/protocols/protocol, + ../libp2p/muxers/muxer, ../libp2p/muxers/mplex/mplex, ../libp2p/muxers/mplex/types @@ -17,6 +22,7 @@ type method init(p: TestProto) {.gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} = let msg = cast[string](await conn.readLp()) + echo msg check "Hello!" == msg await conn.writeLp("Hello!") @@ -43,7 +49,7 @@ suite "Switch": proc testSwitch(): Future[bool] {.async, gcsafe.} = let ma1: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53370") - let ma2: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53381") + let ma2: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53371") var peerInfo1, peerInfo2: PeerInfo var switch1, switch2: Switch @@ -58,11 +64,17 @@ suite "Switch": (switch2, peerInfo2) = createSwitch(ma2) await switch2.start() let conn = await switch2.dial(peerInfo1, TestCodec) + echo "DIALED???" + echo conn.repr await conn.writeLp("Hello!") + echo "WROTE FROM TEST" + echo conn.repr let msg = cast[string](await conn.readLp()) + echo msg check "Hello!" == msg result = true + await allFutures(switch1.stop(), switch2.stop()) check: waitFor(testSwitch()) == true