diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 3d48327..38d3609 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -9,19 +9,22 @@ import tables, sequtils, options, strformat import chronos, chronicles -import connection, - transports/transport, - stream/lpstream, - multistream, +import connection, + transports/transport, + stream/lpstream, + multistream, protocols/protocol, protocols/secure/secure, # for plain text - peerinfo, + peerinfo, multiaddress, - protocols/identify, + protocols/identify, + protocols/pubsub/pubsub, muxers/muxer, peer type + NoPubSubException = object of CatchableError + Switch* = ref object of RootObj peerInfo*: PeerInfo connections*: TableRef[string, Connection] @@ -33,6 +36,10 @@ type identity*: Identify streamHandler*: StreamHandler secureManagers*: seq[Secure] + pubSub*: Option[PubSub] + +proc newNoPubSubException(): ref Exception {.inline.} = + result = newException(NoPubSubException, "no pubsub provided!") proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = ## secure the incoming connection @@ -55,17 +62,25 @@ proc identify(s: Switch, conn: Connection) {.async, gcsafe.} = if (await s.ms.select(conn, s.identity.codec)): let info = await s.identity.identify(conn, conn.peerInfo) - let id = if conn.peerInfo.isSome: - conn.peerInfo.get().peerId.pretty - else: + let id = if conn.peerInfo.peerId.isSome: + conn.peerInfo.peerId.get().pretty + else: "" + if id.len > 0 and s.connections.contains(id): let connection = s.connections[id] - var peerInfo = conn.peerInfo.get() - peerInfo.peerId = PeerID.init(info.pubKey) # we might not have a peerId at all - peerInfo.addrs = info.addrs - peerInfo.protocols = info.protos - debug "identify: identified remote peer ", peer = peerInfo.peerId.pretty + var peerInfo = conn.peerInfo + + if info.pubKey.isSome: + peerInfo.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all + + if info.addrs.len > 0: + peerInfo.addrs = info.addrs + + if info.protos.len > 0: + peerInfo.protocols = info.protos + + debug "identify: identified remote peer ", peer = peerInfo.peerId.get().pretty except IdentityInvalidMsgError as exc: debug "identify: invalid message", msg = exc.msg except IdentityNoMatchError as exc: @@ -90,8 +105,7 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = # add muxer handler cleanup proc handlerFut.addCallback( proc(udata: pointer = nil) {.gcsafe.} = - if handlerFut.finished: - debug "mux: Muxer handler completed for peer ", peer = conn.peerInfo.get().peerId.pretty + debug "mux: Muxer handler completed for peer ", peer = conn.peerInfo.peerId.get().pretty ) await s.identify(stream) await stream.close() # close idenity stream @@ -101,14 +115,14 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = # on exit even if there is no peer for it. This shouldn't # happen once secio is in place, but still something to keep # in mind - if conn.peerInfo.isSome: - s.muxed[conn.peerInfo.get().peerId.pretty] = muxer + if conn.peerInfo.peerId.isSome: + s.muxed[conn.peerInfo.peerId.get().pretty] = muxer proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = result = conn ## perform upgrade flow - if result.peerInfo.isSome: - let id = result.peerInfo.get().peerId.pretty + if result.peerInfo.peerId.isSome: + let id = result.peerInfo.peerId.get().pretty if s.connections.contains(id): # if we already have a connection for this peer, # close the incoming connection and return the @@ -121,14 +135,22 @@ proc handleConn(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe await s.mux(result) # mux it if possible proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = - if conn.peerInfo.isSome: - let id = conn.peerInfo.get().peerId.pretty + if conn.peerInfo.peerId.isSome: + let id = conn.peerInfo.peerId.get().pretty if s.muxed.contains(id): await s.muxed[id].close if s.connections.contains(id): await s.connections[id].close() +proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} = + # if there is a muxer for the connection + # use it instead to create a muxed stream + if s.muxed.contains(peerInfo.peerId.get().pretty): + let muxer = s.muxed[peerInfo.peerId.get().pretty] + let conn = await muxer.newStream() + result = some(conn) + proc dial*(s: Switch, peer: PeerInfo, proto: string = ""): @@ -137,13 +159,13 @@ proc dial*(s: Switch, for a in peer.addrs: # for each address if t.handles(a): # check if it can dial it result = await t.dial(a) - result.peerInfo = some(peer) + # make sure to assign the peer to the connection + result.peerInfo = peer result = await s.handleConn(result) - # if there is a muxer for the connection - # use it instead to create a muxed stream - if s.muxed.contains(peer.peerId.pretty): - result = await s.muxed[peer.peerId.pretty].newStream() + let stream = await s.getMuxedStream(peer) + if stream.isSome: + result = stream.get() debug "dial: attempting to select remote ", proto = proto if not (await s.ms.select(result, proto)): @@ -178,11 +200,35 @@ proc start*(s: Switch) {.async.} = proc stop*(s: Switch) {.async.} = await allFutures(s.transports.mapIt(it.close())) +proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = + if s.pubSub.isSome: + let conn = await s.dial(peerInfo, s.pubSub.get().codec) + await s.pubSub.get().subscribeToPeer(conn) + +proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] {.gcsafe.} = + if s.pubSub.isNone: + raise newNoPubSubException() + + result = s.pubSub.get().subscribe(topic, handler) + +proc unsubscribe*(s: Switch, topics: seq[string]): Future[void] {.gcsafe.} = + if s.pubSub.isNone: + raise newNoPubSubException() + + result = s.pubSub.get().unsubscribe(topics) + +proc publish*(s: Switch, topic: string, data: seq[byte]): Future[void] {.gcsafe.} = + if s.pubSub.isNone: + raise newNoPubSubException() + + result = s.pubSub.get().publish(topic, data) + proc newSwitch*(peerInfo: PeerInfo, transports: seq[Transport], identity: Identify, muxers: Table[string, MuxerProvider], - secureManagers: seq[Secure] = @[]): Switch = + secureManagers: seq[Secure] = @[], + pubSub: Option[PubSub] = none(PubSub)): Switch = new result result.peerInfo = peerInfo result.ms = newMultistream() @@ -216,3 +262,7 @@ proc newSwitch*(peerInfo: PeerInfo, result.secureManagers.add(manager) result.secureManagers = result.secureManagers.deduplicate() + + if pubSub.isSome: + result.pubSub = pubSub + result.mount(pubSub.get())