diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 11188184d..3f9fa36b8 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -7,36 +7,38 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import tables, sequtils +import tables, sequtils, options, strformat import chronos import connection, transport, stream/lpstream, multistream, protocol, peerinfo, multiaddress, - identify, muxers/muxer + identify, muxers/muxer, + peer type - UnableToSecureError = object of CatchableError - UnableToIdentifyError = object of CatchableError - Switch* = ref object of RootObj peerInfo*: PeerInfo connections*: TableRef[string, Connection] + muxed*: TableRef[string, Muxer] transports*: seq[Transport] protocols*: seq[LPProtocol] - muxers*: seq[MuxerProvider] + muxers*: Table[string, MuxerProvider] ms*: MultisteamSelect identity*: Identify +proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} + proc newSwitch*(peerInfo: PeerInfo, transports: seq[Transport], identity: Identify, - muxers: seq[MuxerProvider]): Switch = + muxers: Table[string, MuxerProvider]): Switch = new result result.peerInfo = peerInfo result.ms = newMultistream() result.transports = transports result.connections = newTable[string, Connection]() + result.muxed = newTable[string, Muxer]() result.identity = identity result.muxers = muxers @@ -46,34 +48,85 @@ proc secure(s: Switch, conn: Connection) {.async, gcsafe.} = ## secure the incoming connection discard -proc identify(s: Switch, conn: Connection, peerInfo: PeerInfo) {.async, gcsafe.} = +proc identify(s: Switch, conn: Connection) {.async, gcsafe.} = ## identify the connection - s.peerInfo.protocols = s.ms.list() # update protos before engagin in identify - await s.identity.identify(conn) + s.peerInfo.protocols = await s.ms.list(conn) # update protos before engagin in identify + let info = await s.identity.identify(conn, conn.peerInfo) -proc mux(s: Switch, conn: Connection): Future[bool] {.async, gcsafe.} = + let id = if conn.peerInfo.isSome: conn.peerInfo.get().peerId.pretty else: "" + if s.connections.contains(id): + let connection = s.connections[id] + var peerInfo = conn.peerInfo.get() + peerInfo.peerId = PeerID.init(info.pubKey) # we might not have a peerId at all + peerInfo.addrs = info.addrs + peerInfo.protocols = info.protos + +proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = ## mux incoming connection - result = true + let muxers = toSeq(s.muxers.keys) + let muxerName = await s.ms.select(conn, muxers) + if muxerName.len == 0: + return -proc handleConn(s: Switch, conn: Connection) {.async, gcsafe.} = + let muxer = s.muxers[muxerName].newMuxer(conn) + # install stream handler + muxer.streamHandler = proc (stream: Connection) {.async, gcsafe.} = + try: + asyncDiscard s.handleConn(stream) + finally: + await stream.close() + + # do identify first, so that we have a + # PeerInfo in case we didn't before + let stream = await muxer.newStream() + await s.identify(stream) + + if conn.peerInfo.isSome: + s.muxed[conn.peerInfo.get().peerId.pretty] = muxer + +proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = ## perform upgrade flow - try: - result = s.ms.handle(conn) # handler incoming connection - await s.secure(conn) - if await s.mux(conn): - await s.identify(conn) - finally: - await conn.close() + + # TODO: figure out proper way of handling this. + # Perhaps it's ok to discard this Future and handle + # errors elsewere? + asyncDiscard s.ms.handle(conn) # handler incoming connection + + if conn.peerInfo.isSome: + let id = conn.peerInfo.get().peerId.pretty + if s.connections.contains(id): + # if we already have a connection for this peer, + # close the incoming connection and return the + # existing one + await conn.close() + return s.connections[id] + s.connections[id] = conn + + await s.secure(conn) # secure the connection + await s.mux(conn) # mux it if possible + result = conn + +proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = + let id = if conn.peerInfo.isSome: conn.peerInfo.get().peerId.pretty else: "" + if conn.peerInfo.isSome: + if s.muxed.contains(id): + await s.muxed[id].close + + if s.connections.contains(id): + await s.connections[id].close() proc dial*(s: Switch, peer: PeerInfo, proto: string = ""): Future[Connection] {.async.} = for t in s.transports: # for each transport for a in peer.addrs: # for each address if t.handles(a): # check if it can dial it - result = await t.dial(a) - await s.secure(result) - if not await s.ms.select(result, proto): + var conn = await t.dial(a) + conn = await s.handleConn(conn) + if s.muxed.contains(peer.peerId.pretty): + conn = await s.muxed[peer.peerId.pretty].newStream() + if (await s.ms.select(conn, proto)).len == 0: raise newException(CatchableError, - "Unable to select protocol: " & proto) + &"Unable to select protocol: {proto}") + result = conn proc mount*[T: LPProtocol](s: Switch, proto: T) = if isNil(proto.handler): @@ -88,7 +141,10 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) = proc start*(s: Switch) {.async.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = - await s.handleConn(conn) + try: + asyncDiscard s.handleConn(conn) + except: + await s.cleanupConn(conn) for t in s.transports: # for each transport for a in s.peerInfo.addrs: diff --git a/tests/testswitch.nim b/tests/testswitch.nim index f64e89ec6..9554646ce 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,11 +1,13 @@ -import unittest +import unittest, tables import chronos import ../libp2p/switch, ../libp2p/multistream, ../libp2p/identify, ../libp2p/connection, ../libp2p/transport, ../libp2p/tcptransport, ../libp2p/multiaddress, ../libp2p/peerinfo, ../libp2p/crypto/crypto, ../libp2p/peer, - ../libp2p/protocol + ../libp2p/protocol, ../libp2p/muxers/muxer, + ../libp2p/muxers/mplex/mplex, + ../libp2p/muxers/mplex/types const TestCodec = "/test/proto/1.0.0" @@ -23,22 +25,33 @@ method init(p: TestProto) {.gcsafe.} = suite "Switch": test "e2e use switch": + proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = + let seckey = PrivateKey.random(RSA) + var peerInfo: PeerInfo + peerInfo.peerId = PeerID.init(seckey) + peerInfo.addrs.add(ma) + let identify = newIdentify(peerInfo) + + proc createMplex(conn: Connection): Muxer = + result = newMplex(conn) + + let mplexProvider = newMuxerProvider(createMplex, MplexCodec) + let transports = @[Transport(newTransport(TcpTransport))] + let muxers = [(MplexCodec, mplexProvider)].toTable() + let switch = newSwitch(peerInfo, transports, identify, muxers) + result = (switch, peerInfo) + proc testSwitch(): Future[bool] {.async, gcsafe.} = let ma1: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53370") let ma2: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53371") var peerInfo1, peerInfo2: PeerInfo var switch1, switch2: Switch - proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = - let seckey = PrivateKey.random(RSA) - var peerInfo: PeerInfo - peerInfo.peerId = PeerID.init(seckey) - peerInfo.addrs.add(ma) - let switch = newSwitch(peerInfo, @[Transport(newTransport(TcpTransport))]) - result = (switch, peerInfo) - (switch1, peerInfo1) = createSwitch(ma1) - let testProto = newProtocol(TestProto, peerInfo1) + let testProto = new TestProto + testProto.handler = proc(conn: Connection, proto: string) + {.async, gcsafe.} = discard + testProto.codec = TestCodec switch1.mount(testProto) await switch1.start()