From 8d5ea43e2bbd9efdb34703fe6249efe2a88cc00f Mon Sep 17 00:00:00 2001 From: Tanguy Date: Wed, 8 Mar 2023 12:30:19 +0100 Subject: [PATCH] Upgrade flow refactoring (#807) --- config.nims | 1 + libp2p/builders.nim | 9 +- libp2p/connmanager.nim | 308 ++++++------------ libp2p/dialer.nim | 88 +++-- libp2p/multistream.nim | 159 +++++---- libp2p/muxers/muxer.nim | 47 +-- libp2p/muxers/yamux/yamux.nim | 2 + libp2p/peerstore.nim | 38 ++- .../protocols/connectivity/autonat/client.nim | 2 +- .../connectivity/autonat/service.nim | 2 +- libp2p/protocols/pubsub/pubsub.nim | 6 +- libp2p/protocols/pubsub/pubsubpeer.nim | 7 + libp2p/protocols/secure/secure.nim | 1 - libp2p/stream/connection.nim | 31 +- libp2p/switch.nim | 39 ++- libp2p/transports/tortransport.nim | 2 +- libp2p/transports/transport.nim | 5 +- libp2p/upgrademngrs/muxedupgrade.nim | 158 ++------- libp2p/upgrademngrs/upgrade.nim | 39 +-- tests/config.nims | 1 + tests/pubsub/testgossipinternal.nim | 3 +- tests/pubsub/testgossipsub.nim | 53 ++- tests/pubsub/utils.nim | 18 +- tests/testconnmngr.nim | 181 ++++------ tests/testidentify.nim | 2 +- tests/testmultistream.nim | 3 +- tests/testnoise.nim | 7 +- tests/testpeerstore.nim | 8 +- 28 files changed, 467 insertions(+), 753 deletions(-) diff --git a/config.nims b/config.nims index 50672979c..5d3e88d2a 100644 --- a/config.nims +++ b/config.nims @@ -12,6 +12,7 @@ switch("warning", "LockLevel:off") if (NimMajor, NimMinor) < (1, 6): --styleCheck:hint else: + switch("warningAsError", "UseBase:on") --styleCheck:error # Avoid some rare stack corruption while using exceptions with a SEH-enabled diff --git a/libp2p/builders.nim b/libp2p/builders.nim index 487763c66..8dd10913b 100644 --- a/libp2p/builders.nim +++ b/libp2p/builders.nim @@ -230,7 +230,7 @@ proc build*(b: SwitchBuilder): Switch identify = Identify.new(peerInfo, b.sendSignedPeerRecord) connManager = ConnManager.new(b.maxConnsPerPeer, b.maxConnections, b.maxIn, b.maxOut) ms = MultistreamSelect.new() - muxedUpgrade = MuxedUpgrade.new(identify, b.muxers, secureManagerInstances, connManager, ms) + muxedUpgrade = MuxedUpgrade.new(b.muxers, secureManagerInstances, connManager, ms) let transports = block: @@ -247,14 +247,13 @@ proc build*(b: SwitchBuilder): Switch let peerStore = if isSome(b.peerStoreCapacity): - PeerStore.new(b.peerStoreCapacity.get()) + PeerStore.new(identify, b.peerStoreCapacity.get()) else: - PeerStore.new() + PeerStore.new(identify) let switch = newSwitch( peerInfo = peerInfo, transports = transports, - identity = identify, secureManagers = secureManagerInstances, connManager = connManager, ms = ms, @@ -262,6 +261,8 @@ proc build*(b: SwitchBuilder): Switch peerStore = peerStore, services = b.services) + switch.mount(identify) + if b.autonat: let autonat = Autonat.new(switch) switch.mount(autonat) diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index ac3a0b69a..bde345046 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -55,7 +55,6 @@ type PeerEventKind* {.pure.} = enum Left, - Identified, Joined PeerEvent* = object @@ -68,19 +67,14 @@ type PeerEventHandler* = proc(peerId: PeerId, event: PeerEvent): Future[void] {.gcsafe, raises: [Defect].} - MuxerHolder = object - muxer: Muxer - handle: Future[void] - ConnManager* = ref object of RootObj maxConnsPerPeer: int inSema*: AsyncSemaphore outSema*: AsyncSemaphore - conns: Table[PeerId, HashSet[Connection]] - muxed: Table[Connection, MuxerHolder] + muxed: Table[PeerId, seq[Muxer]] connEvents: array[ConnEventKind, OrderedSet[ConnEventHandler]] peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]] - expectedConnectionsOverLimit*: Table[(PeerId, Direction), Future[Connection]] + expectedConnectionsOverLimit*: Table[(PeerId, Direction), Future[Muxer]] peerStore*: PeerStore ConnectionSlot* = object @@ -110,12 +104,12 @@ proc new*(C: type ConnManager, outSema: outSema) proc connCount*(c: ConnManager, peerId: PeerId): int = - c.conns.getOrDefault(peerId).len + c.muxed.getOrDefault(peerId).len proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] = var peers = newSeq[PeerId]() - for peerId, conns in c.conns: - if conns.anyIt(it.dir == dir): + for peerId, mux in c.muxed: + if mux.anyIt(it.connection.dir == dir): peers.add(peerId) return peers @@ -202,14 +196,6 @@ proc triggerPeerEvents*(c: ConnManager, return try: - let count = c.connCount(peerId) - if event.kind == PeerEventKind.Joined and count != 1: - trace "peer already joined", peer = peerId, event = $event - return - elif event.kind == PeerEventKind.Left and count != 0: - trace "peer still connected or already left", peer = peerId, event = $event - return - trace "triggering peer events", peer = peerId, event = $event var peerEvents: seq[Future[void]] @@ -222,13 +208,13 @@ proc triggerPeerEvents*(c: ConnManager, except CatchableError as exc: # handlers should not raise! warn "Exception in triggerPeerEvents", exc = exc.msg, peer = peerId -proc expectConnection*(c: ConnManager, p: PeerId, dir: Direction): Future[Connection] {.async.} = +proc expectConnection*(c: ConnManager, p: PeerId, dir: Direction): Future[Muxer] {.async.} = ## Wait for a peer to connect to us. This will bypass the `MaxConnectionsPerPeer` let key = (p, dir) if key in c.expectedConnectionsOverLimit: raise newException(AlreadyExpectingConnectionError, "Already expecting an incoming connection from that peer") - let future = newFuture[Connection]() + let future = newFuture[Muxer]() c.expectedConnectionsOverLimit[key] = future try: @@ -236,18 +222,8 @@ proc expectConnection*(c: ConnManager, p: PeerId, dir: Direction): Future[Connec finally: c.expectedConnectionsOverLimit.del(key) -proc contains*(c: ConnManager, conn: Connection): bool = - ## checks if a connection is being tracked by the - ## connection manager - ## - - if isNil(conn): - return - - return conn in c.conns.getOrDefault(conn.peerId) - proc contains*(c: ConnManager, peerId: PeerId): bool = - peerId in c.conns + peerId in c.muxed proc contains*(c: ConnManager, muxer: Muxer): bool = ## checks if a muxer is being tracked by the connection @@ -255,185 +231,134 @@ proc contains*(c: ConnManager, muxer: Muxer): bool = ## if isNil(muxer): - return + return false let conn = muxer.connection - if conn notin c: - return + return muxer in c.muxed.getOrDefault(conn.peerId) - if conn notin c.muxed: - return +proc closeMuxer(muxer: Muxer) {.async.} = + trace "Cleaning up muxer", m = muxer - return muxer == c.muxed.getOrDefault(conn).muxer - -proc closeMuxerHolder(muxerHolder: MuxerHolder) {.async.} = - trace "Cleaning up muxer", m = muxerHolder.muxer - - await muxerHolder.muxer.close() - if not(isNil(muxerHolder.handle)): + await muxer.close() + if not(isNil(muxer.handler)): try: - await muxerHolder.handle # TODO noraises? + await muxer.handler # TODO noraises? except CatchableError as exc: trace "Exception in close muxer handler", exc = exc.msg - trace "Cleaned up muxer", m = muxerHolder.muxer - -proc delConn(c: ConnManager, conn: Connection) = - let peerId = conn.peerId - c.conns.withValue(peerId, peerConns): - peerConns[].excl(conn) - - if peerConns[].len == 0: - c.conns.del(peerId) # invalidates `peerConns` - - libp2p_peers.set(c.conns.len.int64) - trace "Removed connection", conn - -proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = - ## clean connection's resources such as muxers and streams - - if isNil(conn): - trace "Wont cleanup a nil connection" - return - - # Remove connection from all tables without async breaks - var muxer = some(MuxerHolder()) - if not c.muxed.pop(conn, muxer.get()): - muxer = none(MuxerHolder) - - delConn(c, conn) + trace "Cleaned up muxer", m = muxer +proc muxCleanup(c: ConnManager, mux: Muxer) {.async.} = try: - if muxer.isSome: - await closeMuxerHolder(muxer.get()) - finally: - await conn.close() + trace "Triggering disconnect events", mux + let peerId = mux.connection.peerId - trace "Connection cleaned up", conn + let muxers = c.muxed.getOrDefault(peerId).filterIt(it != mux) + if muxers.len > 0: + c.muxed[peerId] = muxers + else: + c.muxed.del(peerId) + libp2p_peers.set(c.muxed.len.int64) + await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left)) -proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} = - try: - trace "Triggering connect events", conn - conn.upgrade() + if not(c.peerStore.isNil): + c.peerStore.cleanup(peerId) - let peerId = conn.peerId - await c.triggerPeerEvents( - peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out)) - - await c.triggerConnEvent( - peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In)) - except CatchableError as exc: - # This is top-level procedure which will work as separate task, so it - # do not need to propagate CancelledError and should handle other errors - warn "Unexpected exception in switch peer connection cleanup", - conn, msg = exc.msg - -proc peerCleanup(c: ConnManager, conn: Connection) {.async.} = - try: - trace "Triggering disconnect events", conn - let peerId = conn.peerId await c.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Disconnected)) - await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left)) - - if not(c.peerStore.isNil): - c.peerStore.cleanup(peerId) except CatchableError as exc: # This is top-level procedure which will work as separate task, so it # do not need to propagate CancelledError and should handle other errors warn "Unexpected exception peer cleanup handler", - conn, msg = exc.msg + mux, msg = exc.msg -proc onClose(c: ConnManager, conn: Connection) {.async.} = +proc onClose(c: ConnManager, mux: Muxer) {.async.} = ## connection close even handler ## ## triggers the connections resource cleanup ## try: - await conn.join() - trace "Connection closed, cleaning up", conn - await c.cleanupConn(conn) - except CancelledError: - # This is top-level procedure which will work as separate task, so it - # do not need to propagate CancelledError. - debug "Unexpected cancellation in connection manager's cleanup", conn + await mux.connection.join() + trace "Connection closed, cleaning up", mux except CatchableError as exc: debug "Unexpected exception in connection manager's cleanup", - errMsg = exc.msg, conn + errMsg = exc.msg, mux finally: - trace "Triggering peerCleanup", conn - asyncSpawn c.peerCleanup(conn) + await c.muxCleanup(mux) -proc selectConn*(c: ConnManager, +proc selectMuxer*(c: ConnManager, peerId: PeerId, - dir: Direction): Connection = + dir: Direction): Muxer = ## Select a connection for the provided peer and direction ## let conns = toSeq( - c.conns.getOrDefault(peerId)) - .filterIt( it.dir == dir ) + c.muxed.getOrDefault(peerId)) + .filterIt( it.connection.dir == dir ) if conns.len > 0: return conns[0] -proc selectConn*(c: ConnManager, peerId: PeerId): Connection = +proc selectMuxer*(c: ConnManager, peerId: PeerId): Muxer = ## Select a connection for the provided giving priority ## to outgoing connections ## - var conn = c.selectConn(peerId, Direction.Out) - if isNil(conn): - conn = c.selectConn(peerId, Direction.In) - if isNil(conn): + var mux = c.selectMuxer(peerId, Direction.Out) + if isNil(mux): + mux = c.selectMuxer(peerId, Direction.In) + if isNil(mux): trace "connection not found", peerId + return mux - return conn - -proc selectMuxer*(c: ConnManager, conn: Connection): Muxer = - ## select the muxer for the provided connection +proc storeMuxer*(c: ConnManager, + muxer: Muxer) + {.raises: [Defect, CatchableError].} = + ## store the connection and muxer ## - if isNil(conn): - return + if isNil(muxer): + raise newException(LPError, "muxer cannot be nil") - if conn in c.muxed: - return c.muxed.getOrDefault(conn).muxer - else: - debug "no muxer for connection", conn + if isNil(muxer.connection): + raise newException(LPError, "muxer's connection cannot be nil") -proc storeConn*(c: ConnManager, conn: Connection) - {.raises: [Defect, LPError].} = - ## store a connection - ## - - if isNil(conn): - raise newException(LPError, "Connection cannot be nil") - - if conn.closed or conn.atEof: + if muxer.connection.closed or muxer.connection.atEof: raise newException(LPError, "Connection closed or EOF") - let peerId = conn.peerId + let + peerId = muxer.connection.peerId + dir = muxer.connection.dir # we use getOrDefault in the if below instead of [] to avoid the KeyError - if c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer: - let key = (peerId, conn.dir) + if c.muxed.getOrDefault(peerId).len > c.maxConnsPerPeer: + let key = (peerId, dir) let expectedConn = c.expectedConnectionsOverLimit.getOrDefault(key) if expectedConn != nil and not expectedConn.finished: - expectedConn.complete(conn) + expectedConn.complete(muxer) else: debug "Too many connections for peer", - conn, conns = c.conns.getOrDefault(peerId).len + conns = c.muxed.getOrDefault(peerId).len raise newTooManyConnectionsError() - c.conns.mgetOrPut(peerId, HashSet[Connection]()).incl(conn) - libp2p_peers.set(c.conns.len.int64) + assert muxer notin c.muxed.getOrDefault(peerId) - # Launch on close listener - # All the errors are handled inside `onClose()` procedure. - asyncSpawn c.onClose(conn) + let + newPeer = peerId notin c.muxed + assert newPeer or c.muxed[peerId].len > 0 + c.muxed.mgetOrPut(peerId, newSeq[Muxer]()).add(muxer) + libp2p_peers.set(c.muxed.len.int64) - trace "Stored connection", - conn, direction = $conn.dir, connections = c.conns.len + asyncSpawn c.triggerConnEvent( + peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: dir == Direction.In)) + + if newPeer: + asyncSpawn c.triggerPeerEvents( + peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: dir == Direction.Out)) + + asyncSpawn c.onClose(muxer) + + trace "Stored muxer", + muxer, direction = $muxer.connection.dir, peers = c.muxed.len proc getIncomingSlot*(c: ConnManager): Future[ConnectionSlot] {.async.} = await c.inSema.acquire() @@ -476,39 +401,17 @@ proc trackConnection*(cs: ConnectionSlot, conn: Connection) = asyncSpawn semaphoreMonitor() -proc storeMuxer*(c: ConnManager, - muxer: Muxer, - handle: Future[void] = nil) - {.raises: [Defect, CatchableError].} = - ## store the connection and muxer - ## - - if isNil(muxer): - raise newException(CatchableError, "muxer cannot be nil") - - if isNil(muxer.connection): - raise newException(CatchableError, "muxer's connection cannot be nil") - - if muxer.connection notin c: - raise newException(CatchableError, "cant add muxer for untracked connection") - - c.muxed[muxer.connection] = MuxerHolder( - muxer: muxer, - handle: handle) - - trace "Stored muxer", - muxer, handle = not handle.isNil, connections = c.conns.len - - asyncSpawn c.onConnUpgraded(muxer.connection) +proc trackMuxer*(cs: ConnectionSlot, mux: Muxer) = + if isNil(mux): + cs.release() + return + cs.trackConnection(mux.connection) proc getStream*(c: ConnManager, - peerId: PeerId, - dir: Direction): Future[Connection] {.async, gcsafe.} = - ## get a muxed stream for the provided peer - ## with the given direction + muxer: Muxer): Future[Connection] {.async, gcsafe.} = + ## get a muxed stream for the passed muxer ## - let muxer = c.selectMuxer(c.selectConn(peerId, dir)) if not(isNil(muxer)): return await muxer.newStream() @@ -517,40 +420,25 @@ proc getStream*(c: ConnManager, ## get a muxed stream for the passed peer from any connection ## - let muxer = c.selectMuxer(c.selectConn(peerId)) - if not(isNil(muxer)): - return await muxer.newStream() + return await c.getStream(c.selectMuxer(peerId)) proc getStream*(c: ConnManager, - conn: Connection): Future[Connection] {.async, gcsafe.} = - ## get a muxed stream for the passed connection + peerId: PeerId, + dir: Direction): Future[Connection] {.async, gcsafe.} = + ## get a muxed stream for the passed peer from a connection with `dir` ## - let muxer = c.selectMuxer(conn) - if not(isNil(muxer)): - return await muxer.newStream() + return await c.getStream(c.selectMuxer(peerId, dir)) + proc dropPeer*(c: ConnManager, peerId: PeerId) {.async.} = ## drop connections and cleanup resources for peer ## trace "Dropping peer", peerId - let conns = c.conns.getOrDefault(peerId) - for conn in conns: - trace "Removing connection", conn - delConn(c, conn) - - var muxers: seq[MuxerHolder] - for conn in conns: - if conn in c.muxed: - muxers.add c.muxed[conn] - c.muxed.del(conn) + let muxers = c.muxed.getOrDefault(peerId) for muxer in muxers: - await closeMuxerHolder(muxer) - - for conn in conns: - await conn.close() - trace "Dropped peer", peerId + await closeMuxer(muxer) trace "Peer dropped", peerId @@ -560,9 +448,6 @@ proc close*(c: ConnManager) {.async.} = ## trace "Closing ConnManager" - let conns = c.conns - c.conns.clear() - let muxed = c.muxed c.muxed.clear() @@ -572,12 +457,9 @@ proc close*(c: ConnManager) {.async.} = for _, fut in expected: await fut.cancelAndWait() - for _, muxer in muxed: - await closeMuxerHolder(muxer) - - for _, conns2 in conns: - for conn in conns2: - await conn.close() + for _, muxers in muxed: + for mux in muxers: + await closeMuxer(mux) trace "Closed ConnManager" diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 3ca4b012f..a5e7e0c27 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -17,7 +17,9 @@ import pkg/[chronos, import dial, peerid, peerinfo, + peerstore, multicodec, + muxers/muxer, multistream, connmanager, stream/connection, @@ -41,10 +43,10 @@ type Dialer* = ref object of Dial localPeerId*: PeerId - ms: MultistreamSelect connManager: ConnManager dialLock: Table[PeerId, AsyncLock] transports: seq[Transport] + peerStore: PeerStore nameResolver: NameResolver proc dialAndUpgrade( @@ -52,7 +54,7 @@ proc dialAndUpgrade( peerId: Opt[PeerId], hostname: string, address: MultiAddress): - Future[Connection] {.async.} = + Future[Muxer] {.async.} = for transport in self.transports: # for each transport if transport.handles(address): # check if it can dial it @@ -75,7 +77,7 @@ proc dialAndUpgrade( libp2p_successful_dials.inc() - let conn = + let mux = try: await transport.upgradeOutgoing(dialed, peerId) except CatchableError as exc: @@ -89,9 +91,9 @@ proc dialAndUpgrade( # Try other address return nil - doAssert not isNil(conn), "connection died after upgradeOutgoing" - debug "Dial successful", conn, peerId = conn.peerId - return conn + doAssert not isNil(mux), "connection died after upgradeOutgoing" + debug "Dial successful", peerId = mux.connection.peerId + return mux return nil proc expandDnsAddr( @@ -126,7 +128,7 @@ proc dialAndUpgrade( self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress]): - Future[Connection] {.async.} = + Future[Muxer] {.async.} = debug "Dialing peer", peerId @@ -147,21 +149,13 @@ proc dialAndUpgrade( if not isNil(result): return result -proc tryReusingConnection(self: Dialer, peerId: PeerId): Future[Opt[Connection]] {.async.} = - var conn = self.connManager.selectConn(peerId) - if conn == nil: - return Opt.none(Connection) +proc tryReusingConnection(self: Dialer, peerId: PeerId): Future[Opt[Muxer]] {.async.} = + let muxer = self.connManager.selectMuxer(peerId) + if muxer == nil: + return Opt.none(Muxer) - if conn.atEof or conn.closed: - # This connection should already have been removed from the connection - # manager - it's essentially a bug that we end up here - we'll fail - # for now, hoping that this will clean themselves up later... - warn "dead connection in connection manager", conn - await conn.close() - raise newException(DialFailedError, "Zombie connection encountered") - - trace "Reusing existing connection", conn, direction = $conn.dir - return Opt.some(conn) + trace "Reusing existing connection", muxer, direction = $muxer.connection.dir + return Opt.some(muxer) proc internalConnect( self: Dialer, @@ -169,7 +163,7 @@ proc internalConnect( addrs: seq[MultiAddress], forceDial: bool, reuseConnection = true): - Future[Connection] {.async.} = + Future[Muxer] {.async.} = if Opt.some(self.localPeerId) == peerId: raise newException(CatchableError, "can't dial self!") @@ -179,32 +173,30 @@ proc internalConnect( await lock.acquire() if peerId.isSome and reuseConnection: - let connOpt = await self.tryReusingConnection(peerId.get()) - if connOpt.isSome: - return connOpt.get() + let muxOpt = await self.tryReusingConnection(peerId.get()) + if muxOpt.isSome: + return muxOpt.get() let slot = self.connManager.getOutgoingSlot(forceDial) - let conn = + let muxed = try: await self.dialAndUpgrade(peerId, addrs) except CatchableError as exc: slot.release() raise exc - slot.trackConnection(conn) - if isNil(conn): # None of the addresses connected + slot.trackMuxer(muxed) + if isNil(muxed): # None of the addresses connected raise newException(DialFailedError, "Unable to establish outgoing link") - # A disconnect could have happened right after - # we've added the connection so we check again - # to prevent races due to that. - if conn.closed() or conn.atEof(): - # This can happen when the other ends drops us - # before we get a chance to return the connection - # back to the dialer. - trace "Connection dead on arrival", conn - raise newLPStreamClosedError() + try: + self.connManager.storeMuxer(muxed) + await self.peerStore.identify(muxed) + except CatchableError as exc: + trace "Failed to finish outgoung upgrade", err=exc.msg + await muxed.close() + raise exc - return conn + return muxed finally: if lock.locked(): lock.release() @@ -235,21 +227,21 @@ method connect*( return (await self.internalConnect( Opt.some(fullAddress.get()[0]), @[fullAddress.get()[1]], - false)).peerId + false)).connection.peerId else: if allowUnknownPeerId == false: raise newException(DialFailedError, "Address without PeerID and unknown peer id disabled!") return (await self.internalConnect( Opt.none(PeerId), @[address], - false)).peerId + false)).connection.peerId proc negotiateStream( self: Dialer, conn: Connection, protos: seq[string]): Future[Connection] {.async.} = trace "Negotiating stream", conn, protos - let selected = await self.ms.select(conn, protos) + let selected = await MultistreamSelect.select(conn, protos) if not protos.contains(selected): await conn.closeWithEOF() raise newException(DialFailedError, "Unable to select sub-protocol " & $protos) @@ -267,11 +259,11 @@ method tryDial*( trace "Check if it can dial", peerId, addrs try: - let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs) - if conn.isNil(): + let mux = await self.dialAndUpgrade(Opt.some(peerId), addrs) + if mux.isNil(): raise newException(DialFailedError, "No valid multiaddress") - await conn.close() - return conn.observedAddr + await mux.close() + return mux.connection.observedAddr except CancelledError as exc: raise exc except CatchableError as exc: @@ -303,7 +295,7 @@ method dial*( ## var - conn: Connection + conn: Muxer stream: Connection proc cleanup() {.async.} = @@ -340,12 +332,12 @@ proc new*( T: type Dialer, localPeerId: PeerId, connManager: ConnManager, + peerStore: PeerStore, transports: seq[Transport], - ms: MultistreamSelect, nameResolver: NameResolver = nil): Dialer = T(localPeerId: localPeerId, connManager: connManager, transports: transports, - ms: ms, + peerStore: peerStore, nameResolver: nameResolver) diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index f34339ac3..9f9e75f1b 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -21,12 +21,11 @@ logScope: topics = "libp2p multistream" const - MsgSize* = 1024 - Codec* = "/multistream/1.0.0" + MsgSize = 1024 + Codec = "/multistream/1.0.0" - MSCodec* = "\x13" & Codec & "\n" - Na* = "\x03na\n" - Ls* = "\x03ls\n" + Na = "na\n" + Ls = "ls\n" type Matcher* = proc (proto: string): bool {.gcsafe, raises: [Defect].} @@ -45,7 +44,7 @@ type proc new*(T: typedesc[MultistreamSelect]): T = T( - codec: MSCodec, + codec: Codec, ) template validateSuffix(str: string): untyped = @@ -54,13 +53,13 @@ template validateSuffix(str: string): untyped = else: raise newException(MultiStreamError, "MultistreamSelect failed, malformed message") -proc select*(m: MultistreamSelect, +proc select*(_: MultistreamSelect | type MultistreamSelect, conn: Connection, proto: seq[string]): Future[string] {.async.} = - trace "initiating handshake", conn, codec = m.codec + trace "initiating handshake", conn, codec = Codec ## select a remote protocol - await conn.write(m.codec) # write handshake + await conn.writeLp(Codec & "\n") # write handshake if proto.len() > 0: trace "selecting proto", conn, proto = proto[0] await conn.writeLp((proto[0] & "\n")) # select proto @@ -102,13 +101,13 @@ proc select*(m: MultistreamSelect, # No alternatives, fail return "" -proc select*(m: MultistreamSelect, +proc select*(_: MultistreamSelect | type MultistreamSelect, conn: Connection, proto: string): Future[bool] {.async.} = if proto.len > 0: - return (await m.select(conn, @[proto])) == proto + return (await MultistreamSelect.select(conn, @[proto])) == proto else: - return (await m.select(conn, @[])) == Codec + return (await MultistreamSelect.select(conn, @[])) == Codec proc select*(m: MultistreamSelect, conn: Connection): Future[bool] = m.select(conn, "") @@ -119,7 +118,7 @@ proc list*(m: MultistreamSelect, if not await m.select(conn): return - await conn.write(Ls) # send ls + await conn.writeLp(Ls) # send ls var list = newSeq[string]() let ms = string.fromBytes(await conn.readLp(MsgSize)) @@ -129,68 +128,86 @@ proc list*(m: MultistreamSelect, result = list +proc handle*( + _: type MultistreamSelect, + conn: Connection, + protos: seq[string], + matchers = newSeq[Matcher](), + active: bool = false, + ): Future[string] {.async, gcsafe.} = + trace "Starting multistream negotiation", conn, handshaked = active + var handshaked = active + while not conn.atEof: + var ms = string.fromBytes(await conn.readLp(MsgSize)) + validateSuffix(ms) + + if not handshaked and ms != Codec: + debug "expected handshake message", conn, instead=ms + raise newException(CatchableError, + "MultistreamSelect handling failed, invalid first message") + + trace "handle: got request", conn, ms + if ms.len() <= 0: + trace "handle: invalid proto", conn + await conn.writeLp(Na) + + case ms: + of "ls": + trace "handle: listing protos", conn + #TODO this doens't seem to follow spec, each protocol + # should be length prefixed. Not very important + # since LS is getting deprecated + await conn.writeLp(protos.join("\n") & "\n") + of Codec: + if not handshaked: + await conn.writeLp(Codec & "\n") + handshaked = true + else: + trace "handle: sending `na` for duplicate handshake while handshaked", + conn + await conn.writeLp(Na) + elif ms in protos or matchers.anyIt(it(ms)): + trace "found handler", conn, protocol = ms + await conn.writeLp(ms & "\n") + conn.protocol = ms + return ms + else: + trace "no handlers", conn, protocol = ms + await conn.writeLp(Na) + proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.async, gcsafe.} = trace "Starting multistream handler", conn, handshaked = active - var handshaked = active + var + handshaked = active + protos: seq[string] + matchers: seq[Matcher] + for h in m.handlers: + if not isNil(h.match): + matchers.add(h.match) + for proto in h.protos: + protos.add(proto) + try: - while not conn.atEof: - var ms = string.fromBytes(await conn.readLp(MsgSize)) - validateSuffix(ms) + let ms = await MultistreamSelect.handle(conn, protos, matchers, active) + for h in m.handlers: + if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms): + trace "found handler", conn, protocol = ms - if not handshaked and ms != Codec: - notice "expected handshake message", conn, instead=ms - raise newException(CatchableError, - "MultistreamSelect handling failed, invalid first message") - - trace "handle: got request", conn, ms - if ms.len() <= 0: - trace "handle: invalid proto", conn - await conn.write(Na) - - if m.handlers.len() == 0: - trace "handle: sending `na` for protocol", conn, protocol = ms - await conn.write(Na) - continue - - case ms: - of "ls": - trace "handle: listing protos", conn - var protos = "" - for h in m.handlers: - for proto in h.protos: - protos &= (proto & "\n") - await conn.writeLp(protos) - of Codec: - if not handshaked: - await conn.write(m.codec) - handshaked = true - else: - trace "handle: sending `na` for duplicate handshake while handshaked", - conn - await conn.write(Na) - else: - for h in m.handlers: - if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms): - trace "found handler", conn, protocol = ms - - var protocolHolder = h - let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams - if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams: - debug "Max streams for protocol reached, blocking new stream", - conn, protocol = ms, maxIncomingStreams - return - protocolHolder.openedStreams.inc(conn.peerId) - try: - await conn.writeLp(ms & "\n") - conn.protocol = ms - await protocolHolder.protocol.handler(conn, ms) - finally: - protocolHolder.openedStreams.inc(conn.peerId, -1) - if protocolHolder.openedStreams[conn.peerId] == 0: - protocolHolder.openedStreams.del(conn.peerId) - return - debug "no handlers", conn, protocol = ms - await conn.write(Na) + var protocolHolder = h + let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams + if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams: + debug "Max streams for protocol reached, blocking new stream", + conn, protocol = ms, maxIncomingStreams + return + protocolHolder.openedStreams.inc(conn.peerId) + try: + await protocolHolder.protocol.handler(conn, ms) + finally: + protocolHolder.openedStreams.inc(conn.peerId, -1) + if protocolHolder.openedStreams[conn.peerId] == 0: + protocolHolder.openedStreams.del(conn.peerId) + return + debug "no handlers", conn, ms except CancelledError as exc: raise exc except CatchableError as exc: diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim index ee9dd9cac..0221ed743 100644 --- a/libp2p/muxers/muxer.nim +++ b/libp2p/muxers/muxer.nim @@ -32,24 +32,28 @@ type Muxer* = ref object of RootObj streamHandler*: StreamHandler + handler*: Future[void] connection*: Connection # user provider proc that returns a constructed Muxer MuxerConstructor* = proc(conn: Connection): Muxer {.gcsafe, closure, raises: [Defect].} # this wraps a creator proc that knows how to make muxers - MuxerProvider* = ref object of LPProtocol + MuxerProvider* = object newMuxer*: MuxerConstructor - streamHandler*: StreamHandler # triggered every time there is a new stream, called for any muxer instance - muxerHandler*: MuxerHandler # triggered every time there is a new muxed connection created + codec*: string -func shortLog*(m: Muxer): auto = shortLog(m.connection) +func shortLog*(m: Muxer): auto = + if isNil(m): "nil" + else: shortLog(m.connection) chronicles.formatIt(Muxer): shortLog(it) # muxer interface method newStream*(m: Muxer, name: string = "", lazy: bool = false): Future[Connection] {.base, async, gcsafe.} = discard -method close*(m: Muxer) {.base, async, gcsafe.} = discard +method close*(m: Muxer) {.base, async, gcsafe.} = + if not isNil(m.connection): + await m.connection.close() method handle*(m: Muxer): Future[void] {.base, async, gcsafe.} = discard proc new*( @@ -57,36 +61,5 @@ proc new*( creator: MuxerConstructor, codec: string): T {.gcsafe.} = - let muxerProvider = T(newMuxer: creator) - muxerProvider.codec = codec - muxerProvider.init() + let muxerProvider = T(newMuxer: creator, codec: codec) muxerProvider - -method init(c: MuxerProvider) = - proc handler(conn: Connection, proto: string) {.async, gcsafe, closure.} = - trace "starting muxer handler", proto=proto, conn - try: - let - muxer = c.newMuxer(conn) - - if not isNil(c.streamHandler): - muxer.streamHandler = c.streamHandler - - var futs = newSeq[Future[void]]() - futs &= muxer.handle() - - # finally await both the futures - if not isNil(c.muxerHandler): - await c.muxerHandler(muxer) - when defined(libp2p_agents_metrics): - conn.shortAgent = muxer.connection.shortAgent - - checkFutures(await allFinished(futs)) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "exception in muxer handler", exc = exc.msg, conn, proto - finally: - await conn.close() - - c.handler = handler diff --git a/libp2p/muxers/yamux/yamux.nim b/libp2p/muxers/yamux/yamux.nim index f60cf72f9..8d94b719d 100644 --- a/libp2p/muxers/yamux/yamux.nim +++ b/libp2p/muxers/yamux/yamux.nim @@ -356,6 +356,8 @@ proc open*(channel: YamuxChannel) {.async, gcsafe.} = channel.opened = true await channel.conn.write(YamuxHeader.data(channel.id, 0, {if channel.isSrc: Syn else: Ack})) +method getWrapped*(channel: YamuxChannel): Connection = channel.conn + type Yamux* = ref object of Muxer channels: Table[uint32, YamuxChannel] diff --git a/libp2p/peerstore.nim b/libp2p/peerstore.nim index 23f31fb2c..5cb1df44b 100644 --- a/libp2p/peerstore.nim +++ b/libp2p/peerstore.nim @@ -28,11 +28,16 @@ else: import std/[tables, sets, options, macros], + chronos, ./crypto/crypto, ./protocols/identify, + ./protocols/protocol, ./peerid, ./peerinfo, ./routing_record, ./multiaddress, + ./stream/connection, + ./multistream, + ./muxers/muxer, utility type @@ -70,11 +75,15 @@ type PeerStore* {.public.} = ref object books: Table[string, BasePeerBook] + identify: Identify capacity*: int toClean*: seq[PeerId] -proc new*(T: type PeerStore, capacity = 1000): PeerStore {.public.} = - T(capacity: capacity) +proc new*(T: type PeerStore, identify: Identify, capacity = 1000): PeerStore {.public.} = + T( + identify: identify, + capacity: capacity + ) ######################### # Generic Peer Book API # @@ -186,3 +195,28 @@ proc cleanup*( while peerStore.toClean.len > peerStore.capacity: peerStore.del(peerStore.toClean[0]) peerStore.toClean.delete(0) + +proc identify*( + peerStore: PeerStore, + muxer: Muxer) {.async.} = + + # new stream for identify + var stream = await muxer.newStream() + if stream == nil: + return + + try: + if (await MultistreamSelect.select(stream, peerStore.identify.codec())): + let info = await peerStore.identify.identify(stream, stream.peerId) + + when defined(libp2p_agents_metrics): + var knownAgent = "unknown" + if info.agentVersion.isSome and info.agentVersion.get().len > 0: + let shortAgent = info.agentVersion.get().split("/")[0].safeToLowerAscii() + if shortAgent.isOk() and KnownLibP2PAgentsSeq.contains(shortAgent.get()): + knownAgent = shortAgent.get() + muxer.connection.setShortAgent(knownAgent) + + peerStore.updatePeerInfo(info) + finally: + await stream.closeWithEOF() diff --git a/libp2p/protocols/connectivity/autonat/client.nim b/libp2p/protocols/connectivity/autonat/client.nim index 13176259c..8a74ef009 100644 --- a/libp2p/protocols/connectivity/autonat/client.nim +++ b/libp2p/protocols/connectivity/autonat/client.nim @@ -65,7 +65,7 @@ method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[Mult await conn.close() incomingConnection.cancel() # Safer to always try to cancel cause we aren't sure if the peer dialled us or not if incomingConnection.completed(): - await (await incomingConnection).close() + await (await incomingConnection).connection.close() trace "sending Dial", addrs = switch.peerInfo.addrs await conn.sendDial(switch.peerInfo.peerId, switch.peerInfo.addrs) let response = getResponseOrRaise(AutonatMsg.decode(await conn.readLp(1024))) diff --git a/libp2p/protocols/connectivity/autonat/service.nim b/libp2p/protocols/connectivity/autonat/service.nim index 6d6ca7143..f75907d33 100644 --- a/libp2p/protocols/connectivity/autonat/service.nim +++ b/libp2p/protocols/connectivity/autonat/service.nim @@ -81,7 +81,7 @@ proc hasEnoughIncomingSlots(switch: Switch): bool = return switch.connManager.slotsAvailable(In) >= 2 proc doesPeerHaveIncomingConn(switch: Switch, peerId: PeerId): bool = - return switch.connManager.selectConn(peerId, In) != nil + return switch.connManager.selectMuxer(peerId, In) != nil proc handleAnswer(self: AutonatService, ans: NetworkReachability) {.async.} = diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 2203f7754..fb7aea13b 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -406,7 +406,11 @@ method onTopicSubscription*(p: PubSub, topic: string, subscribed: bool) {.base, # Notify others that we are no longer interested in the topic for _, peer in p.peers: - p.sendSubs(peer, [topic], subscribed) + # If we don't have a sendConn yet, we will + # send the full sub list when we get the sendConn, + # so no need to send it here + if peer.hasSendConn: + p.sendSubs(peer, [topic], subscribed) if subscribed: libp2p_pubsub_subscriptions.inc() diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index ebdbd4d20..b925b2378 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -177,6 +177,10 @@ proc connectOnce(p: PubSubPeer): Future[void] {.async.} = # stop working so we make an effort to only keep a single channel alive trace "Get new send connection", p, newConn + + # Careful to race conditions here. + # Topic subscription relies on either connectedFut + # to be completed, or onEvent to be called later p.connectedFut.complete() p.sendConn = newConn p.address = if p.sendConn.observedAddr.isSome: some(p.sendConn.observedAddr.get) else: none(MultiAddress) @@ -217,6 +221,9 @@ proc connect*(p: PubSubPeer) = asyncSpawn connectImpl(p) +proc hasSendConn*(p: PubSubPeer): bool = + p.sendConn != nil + template sendMetrics(msg: RPCMsg): untyped = when defined(libp2p_expensive_metrics): for x in msg.messages: diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 905bf281e..e915934c9 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -56,7 +56,6 @@ proc new*(T: type SecureConn, peerId: peerId, observedAddr: observedAddr, closeEvent: conn.closeEvent, - upgraded: conn.upgraded, timeout: timeout, dir: conn.dir) result.initStream() diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index 3f4730e58..86539a058 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -39,7 +39,6 @@ type timeoutHandler*: TimeoutHandler # timeout handler peerId*: PeerId observedAddr*: Opt[MultiAddress] - upgraded*: Future[void] protocol*: string # protocol used by the connection, used as tag for metrics transportDir*: Direction # The bottom level transport (generally the socket) direction when defined(libp2p_agents_metrics): @@ -47,22 +46,6 @@ type proc timeoutMonitor(s: Connection) {.async, gcsafe.} -proc isUpgraded*(s: Connection): bool = - if not isNil(s.upgraded): - return s.upgraded.finished - -proc upgrade*(s: Connection, failed: ref CatchableError = nil) = - if not isNil(s.upgraded): - if not isNil(failed): - s.upgraded.fail(failed) - return - - s.upgraded.complete() - -proc onUpgrade*(s: Connection) {.async.} = - if not isNil(s.upgraded): - await s.upgraded - func shortLog*(conn: Connection): string = try: if conn.isNil: "Connection(nil)" @@ -80,9 +63,6 @@ method initStream*(s: Connection) = doAssert(isNil(s.timerTaskFut)) - if isNil(s.upgraded): - s.upgraded = newFuture[void]() - if s.timeout > 0.millis: trace "Monitoring for timeout", s, timeout = s.timeout @@ -100,10 +80,6 @@ method closeImpl*(s: Connection): Future[void] = s.timerTaskFut.cancel() s.timerTaskFut = nil - if not isNil(s.upgraded) and not s.upgraded.finished: - s.upgraded.cancel() - s.upgraded = nil - trace "Closed connection", s procCall LPStream(s).closeImpl() @@ -158,6 +134,13 @@ proc timeoutMonitor(s: Connection) {.async, gcsafe.} = method getWrapped*(s: Connection): Connection {.base.} = doAssert(false, "not implemented!") +when defined(libp2p_agents_metrics): + proc setShortAgent*(s: Connection, shortAgent: string) = + var conn = s + while not isNil(conn): + conn.shortAgent = shortAgent + conn = conn.getWrapped() + proc new*(C: type Connection, peerId: PeerId, dir: Direction, diff --git a/libp2p/switch.nim b/libp2p/switch.nim index d28ae1b1a..672eca123 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -220,24 +220,27 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) s.ms.addHandler(proto.codecs, proto, matcher) s.peerInfo.protocols.add(proto.codec) -proc upgradeMonitor(conn: Connection, upgrades: AsyncSemaphore) {.async.} = - ## monitor connection for upgrades - ## +proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} = + let muxed = await trans.upgradeIncoming(conn) + switch.connManager.storeMuxer(muxed) + await switch.peerStore.identify(muxed) + trace "Connection upgrade succeeded" + +proc upgradeMonitor( + switch: Switch, + trans: Transport, + conn: Connection, + upgrades: AsyncSemaphore) {.async.} = try: - # Since we don't control the flow of the - # upgrade, this timeout guarantees that a - # "hanged" remote doesn't hold the upgrade - # forever - await conn.onUpgrade.wait(30.seconds) # wait for connection to be upgraded - trace "Connection upgrade succeeded" + await switch.upgrader(trans, conn).wait(30.seconds) except CatchableError as exc: - libp2p_failed_upgrades_incoming.inc() + if exc isnot CancelledError: + libp2p_failed_upgrades_incoming.inc() if not isNil(conn): await conn.close() - trace "Exception awaiting connection upgrade", exc = exc.msg, conn finally: - upgrades.release() # don't forget to release the slot! + upgrades.release() proc accept(s: Switch, transport: Transport) {.async.} = # noraises ## switch accept loop, ran for every transport @@ -278,8 +281,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises conn.transportDir = Direction.In debug "Accepted an incoming connection", conn - asyncSpawn upgradeMonitor(conn, upgrades) - asyncSpawn transport.upgradeIncoming(conn) + asyncSpawn s.upgradeMonitor(transport, conn, upgrades) except CancelledError as exc: trace "releasing semaphore on cancellation" upgrades.release() # always release the slot @@ -377,14 +379,13 @@ proc start*(s: Switch) {.async, gcsafe, public.} = proc newSwitch*(peerInfo: PeerInfo, transports: seq[Transport], - identity: Identify, secureManagers: openArray[Secure] = [], connManager: ConnManager, ms: MultistreamSelect, + peerStore: PeerStore, nameResolver: NameResolver = nil, - peerStore = PeerStore.new(), services = newSeq[Service]()): Switch - {.raises: [Defect, LPError], public.} = + {.raises: [Defect, LPError].} = if secureManagers.len == 0: raise newException(LPError, "Provide at least one secure manager") @@ -394,11 +395,9 @@ proc newSwitch*(peerInfo: PeerInfo, transports: transports, connManager: connManager, peerStore: peerStore, - dialer: Dialer.new(peerInfo.peerId, connManager, transports, ms, nameResolver), + dialer: Dialer.new(peerInfo.peerId, connManager, peerStore, transports, nameResolver), nameResolver: nameResolver, services: services) switch.connManager.peerStore = peerStore - switch.mount(identity) - return switch diff --git a/libp2p/transports/tortransport.nim b/libp2p/transports/tortransport.nim index ee39fabca..371cc616e 100644 --- a/libp2p/transports/tortransport.nim +++ b/libp2p/transports/tortransport.nim @@ -269,7 +269,7 @@ proc new*( transports: switch.transports, connManager: switch.connManager, peerStore: switch.peerStore, - dialer: Dialer.new(switch.peerInfo.peerId, switch.connManager, switch.transports, switch.ms, nil), + dialer: Dialer.new(switch.peerInfo.peerId, switch.connManager, switch.peerStore, switch.transports, nil), nameResolver: nil) torSwitch.connManager.peerStore = switch.peerStore diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index e2b85832f..a5a651d7e 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -18,6 +18,7 @@ import chronos, chronicles import ../stream/connection, ../multiaddress, ../multicodec, + ../muxers/muxer, ../upgrademngrs/upgrade logScope: @@ -80,7 +81,7 @@ proc dial*( method upgradeIncoming*( self: Transport, - conn: Connection): Future[void] {.base, gcsafe.} = + conn: Connection): Future[Muxer] {.base, gcsafe.} = ## base upgrade method that the transport uses to perform ## transport specific upgrades ## @@ -90,7 +91,7 @@ method upgradeIncoming*( method upgradeOutgoing*( self: Transport, conn: Connection, - peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} = + peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} = ## base upgrade method that the transport uses to perform ## transport specific upgrades ## diff --git a/libp2p/upgrademngrs/muxedupgrade.nim b/libp2p/upgrademngrs/muxedupgrade.nim index a805efa19..7c833e0b9 100644 --- a/libp2p/upgrademngrs/muxedupgrade.nim +++ b/libp2p/upgrademngrs/muxedupgrade.nim @@ -30,35 +30,24 @@ type proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider = for m in self.muxers: - if muxerName in m.codecs: + if muxerName == m.codec: return m -proc identify*( - self: MuxedUpgrade, - muxer: Muxer) {.async, gcsafe.} = - # new stream for identify - var stream = await muxer.newStream() - if stream == nil: - return - - try: - await self.identify(stream) - when defined(libp2p_agents_metrics): - muxer.connection.shortAgent = stream.shortAgent - finally: - await stream.closeWithEOF() - proc mux*( self: MuxedUpgrade, - conn: Connection): Future[Muxer] {.async, gcsafe.} = - ## mux outgoing connection + conn: Connection, + direction: Direction): Future[Muxer] {.async, gcsafe.} = + ## mux connection trace "Muxing connection", conn if self.muxers.len == 0: warn "no muxers registered, skipping upgrade flow", conn return - let muxerName = await self.ms.select(conn, self.muxers.mapIt(it.codec)) + let muxerName = + if direction == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec)) + else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec)) + if muxerName.len == 0 or muxerName == "na": debug "no muxer available, early exit", conn return @@ -70,36 +59,23 @@ proc mux*( # install stream handler muxer.streamHandler = self.streamHandler - - self.connManager.storeConn(conn) - - # store it in muxed connections if we have a peer for it - self.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop - - try: - await self.identify(muxer) - except CatchableError as exc: - # Identify is non-essential, though if it fails, it might indicate that - # the connection was closed already - this will be picked up by the read - # loop - debug "Could not identify connection", conn, msg = exc.msg - + muxer.handler = muxer.handle() return muxer -method upgradeOutgoing*( +proc upgrade( self: MuxedUpgrade, conn: Connection, - peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} = - trace "Upgrading outgoing connection", conn + direction: Direction, + peerId: Opt[PeerId]): Future[Muxer] {.async.} = + trace "Upgrading connection", conn, direction - let sconn = await self.secure(conn, peerId) # secure the connection + let sconn = await self.secure(conn, direction, peerId) # secure the connection if isNil(sconn): raise newException(UpgradeFailedError, "unable to secure connection, stopping upgrade") - let muxer = await self.mux(sconn) # mux it if possible + let muxer = await self.mux(sconn, direction) # mux it if possible if muxer == nil: - # TODO this might be relaxed in the future raise newException(UpgradeFailedError, "a muxer is required for outgoing connections") @@ -111,108 +87,28 @@ method upgradeOutgoing*( raise newException(UpgradeFailedError, "Connection closed or missing peer info, stopping upgrade") - trace "Upgraded outgoing connection", conn, sconn + trace "Upgraded connection", conn, sconn, direction + return muxer - return sconn +method upgradeOutgoing*( + self: MuxedUpgrade, + conn: Connection, + peerId: Opt[PeerId]): Future[Muxer] {.async, gcsafe.} = + return await self.upgrade(conn, Out, peerId) method upgradeIncoming*( self: MuxedUpgrade, - incomingConn: Connection) {.async, gcsafe.} = # noraises - trace "Upgrading incoming connection", incomingConn - let ms = MultistreamSelect.new() - - # secure incoming connections - proc securedHandler(conn: Connection, - proto: string) - {.async, gcsafe, closure.} = - trace "Starting secure handler", conn - let secure = self.secureManagers.filterIt(it.codec == proto)[0] - - var cconn = conn - try: - var sconn = await secure.secure(cconn, false, Opt.none(PeerId)) - if isNil(sconn): - return - - cconn = sconn - # add the muxer - for muxer in self.muxers: - ms.addHandler(muxer.codecs, muxer) - - # handle subsequent secure requests - await ms.handle(cconn) - except CatchableError as exc: - debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn - if not cconn.isUpgraded: - cconn.upgrade(exc) - finally: - if not isNil(cconn): - await cconn.close() - - trace "Stopped secure handler", conn - - try: - if (await ms.select(incomingConn)): # just handshake - # add the secure handlers - for k in self.secureManagers: - ms.addHandler(k.codec, securedHandler) - - # handle un-secured connections - # we handshaked above, set this ms handler as active - await ms.handle(incomingConn, active = true) - except CatchableError as exc: - debug "Exception upgrading incoming", exc = exc.msg - if not incomingConn.isUpgraded: - incomingConn.upgrade(exc) - finally: - if not isNil(incomingConn): - await incomingConn.close() - -proc muxerHandler( - self: MuxedUpgrade, - muxer: Muxer) {.async, gcsafe.} = - let - conn = muxer.connection - - # store incoming connection - self.connManager.storeConn(conn) - - # store muxer and muxed connection - self.connManager.storeMuxer(muxer) - - try: - await self.identify(muxer) - when defined(libp2p_agents_metrics): - #TODO Passing data between layers is a pain - if muxer.connection of SecureConn: - let secureConn = (SecureConn)muxer.connection - secureConn.stream.shortAgent = muxer.connection.shortAgent - except IdentifyError as exc: - # Identify is non-essential, though if it fails, it might indicate that - # the connection was closed already - this will be picked up by the read - # loop - debug "Could not identify connection", conn, msg = exc.msg - except LPStreamClosedError as exc: - debug "Identify stream closed", conn, msg = exc.msg - except LPStreamEOFError as exc: - debug "Identify stream EOF", conn, msg = exc.msg - except CancelledError as exc: - await muxer.close() - raise exc - except CatchableError as exc: - await muxer.close() - trace "Exception in muxer handler", conn, msg = exc.msg + conn: Connection): Future[Muxer] {.async, gcsafe.} = + return await self.upgrade(conn, In, Opt.none(PeerId)) proc new*( T: type MuxedUpgrade, - identity: Identify, muxers: seq[MuxerProvider], secureManagers: openArray[Secure] = [], connManager: ConnManager, ms: MultistreamSelect): T = let upgrader = T( - identity: identity, muxers: muxers, secureManagers: @secureManagers, connManager: connManager, @@ -231,10 +127,4 @@ proc new*( await conn.closeWithEOF() trace "Stream handler done", conn - for _, val in muxers: - val.streamHandler = upgrader.streamHandler - val.muxerHandler = proc(muxer: Muxer): Future[void] - {.raises: [Defect].} = - upgrader.muxerHandler(muxer) - return upgrader diff --git a/libp2p/upgrademngrs/upgrade.nim b/libp2p/upgrademngrs/upgrade.nim index 6f738577e..3cc3fc75d 100644 --- a/libp2p/upgrademngrs/upgrade.nim +++ b/libp2p/upgrademngrs/upgrade.nim @@ -19,6 +19,7 @@ import pkg/[chronos, chronicles, metrics] import ../stream/connection, ../protocols/secure/secure, ../protocols/identify, + ../muxers/muxer, ../multistream, ../peerstore, ../connmanager, @@ -37,29 +38,31 @@ type Upgrade* = ref object of RootObj ms*: MultistreamSelect - identity*: Identify connManager*: ConnManager secureManagers*: seq[Secure] method upgradeIncoming*( self: Upgrade, - conn: Connection): Future[void] {.base.} = + conn: Connection): Future[Muxer] {.base.} = doAssert(false, "Not implemented!") method upgradeOutgoing*( self: Upgrade, conn: Connection, - peerId: Opt[PeerId]): Future[Connection] {.base.} = + peerId: Opt[PeerId]): Future[Muxer] {.base.} = doAssert(false, "Not implemented!") proc secure*( self: Upgrade, conn: Connection, + direction: Direction, peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} = if self.secureManagers.len <= 0: raise newException(UpgradeFailedError, "No secure managers registered!") - let codec = await self.ms.select(conn, self.secureManagers.mapIt(it.codec)) + let codec = + if direction == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec)) + else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec)) if codec.len == 0: raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!") @@ -70,30 +73,4 @@ proc secure*( # let's avoid duplicating checks but detect if it fails to do it properly doAssert(secureProtocol.len > 0) - return await secureProtocol[0].secure(conn, true, peerId) - -proc identify*( - self: Upgrade, - conn: Connection) {.async, gcsafe.} = - ## identify the connection - - if (await self.ms.select(conn, self.identity.codec)): - let - info = await self.identity.identify(conn, conn.peerId) - peerStore = self.connManager.peerStore - - if info.pubkey.isNone and isNil(conn): - raise newException(UpgradeFailedError, - "no public key provided and no existing peer identity found") - - conn.peerId = info.peerId - - when defined(libp2p_agents_metrics): - conn.shortAgent = "unknown" - if info.agentVersion.isSome and info.agentVersion.get().len > 0: - let shortAgent = info.agentVersion.get().split("/")[0].safeToLowerAscii() - if shortAgent.isOk() and KnownLibP2PAgentsSeq.contains(shortAgent.get()): - conn.shortAgent = shortAgent.get() - - peerStore.updatePeerInfo(info) - trace "identified remote peer", conn, peerId = shortLog(conn.peerId) + return await secureProtocol[0].secure(conn, direction == Out, peerId) diff --git a/tests/config.nims b/tests/config.nims index d108df249..14179516e 100644 --- a/tests/config.nims +++ b/tests/config.nims @@ -8,6 +8,7 @@ import strutils, os --d:libp2p_protobuf_metrics --d:libp2p_network_protocols_metrics --d:libp2p_mplex_metrics +--d:unittestPrintTime --skipParentCfg # Only add chronicles param if the diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index aa16b03d0..a48caad8a 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -9,6 +9,7 @@ import ../../libp2p/errors import ../../libp2p/crypto/crypto import ../../libp2p/stream/bufferstream import ../../libp2p/switch +import ../../libp2p/muxers/muxer import ../helpers @@ -495,7 +496,7 @@ suite "GossipSub internal": peer.handler = handler peer.appScore = gossipSub.parameters.graylistThreshold - 1 gossipSub.gossipsub.mgetOrPut(topic, initHashSet[PubSubPeer]()).incl(peer) - gossipSub.switch.connManager.storeConn(conn) + gossipSub.switch.connManager.storeMuxer(Muxer(connection: conn)) gossipSub.updateScores() diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 16a71eae1..396bccbab 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -107,11 +107,7 @@ suite "GossipSub": nodes[0].subscribe("foobar", handler) nodes[1].subscribe("foobar", handler) - var subs: seq[Future[void]] - subs &= waitSub(nodes[1], nodes[0], "foobar") - subs &= waitSub(nodes[0], nodes[1], "foobar") - - await allFuturesThrowing(subs) + await waitSubGraph(nodes, "foobar") let gossip1 = GossipSub(nodes[0]) let gossip2 = GossipSub(nodes[1]) @@ -157,11 +153,7 @@ suite "GossipSub": nodes[0].subscribe("foobar", handler) nodes[1].subscribe("foobar", handler) - var subs: seq[Future[void]] - subs &= waitSub(nodes[1], nodes[0], "foobar") - subs &= waitSub(nodes[0], nodes[1], "foobar") - - await allFuturesThrowing(subs) + await waitSubGraph(nodes, "foobar") let gossip1 = GossipSub(nodes[0]) let gossip2 = GossipSub(nodes[1]) @@ -424,8 +416,6 @@ suite "GossipSub": await passed.wait(2.seconds) - trace "test done, stopping..." - await allFuturesThrowing( nodes[0].switch.stop(), nodes[1].switch.stop() @@ -452,21 +442,23 @@ suite "GossipSub": nodes[1].switch.start(), ) + GossipSub(nodes[1]).parameters.d = 0 + GossipSub(nodes[1]).parameters.dHigh = 0 + GossipSub(nodes[1]).parameters.dLow = 0 + await subscribeNodes(nodes) - nodes[1].subscribe("foobar", handler) nodes[0].subscribe("foobar", handler) - await waitSub(nodes[0], nodes[1], "foobar") - await waitSub(nodes[1], nodes[0], "foobar") - - nodes[0].unsubscribe("foobar", handler) + nodes[1].subscribe("foobar", handler) let gsNode = GossipSub(nodes[1]) - checkExpiring: gsNode.mesh.getOrDefault("foobar").len == 0 - - nodes[0].subscribe("foobar", handler) - - check GossipSub(nodes[0]).mesh.getOrDefault("foobar").len == 0 + checkExpiring: + gsNode.mesh.getOrDefault("foobar").len == 0 and + GossipSub(nodes[0]).mesh.getOrDefault("foobar").len == 0 and + ( + GossipSub(nodes[0]).gossipsub.getOrDefault("foobar").len == 1 or + GossipSub(nodes[0]).fanout.getOrDefault("foobar").len == 1 + ) tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 @@ -532,8 +524,8 @@ suite "GossipSub": asyncTest "e2e - GossipSub should not send to source & peers who already seen": # 3 nodes: A, B, C - # A publishes, B relays, C is having a long validation - # so C should not send to anyone + # A publishes, C relays, B is having a long validation + # so B should not send to anyone let nodes = generateNodes( @@ -566,10 +558,7 @@ suite "GossipSub": nodes[0].subscribe("foobar", handlerA) nodes[1].subscribe("foobar", handlerB) nodes[2].subscribe("foobar", handlerC) - await waitSub(nodes[0], nodes[1], "foobar") - await waitSub(nodes[0], nodes[2], "foobar") - await waitSub(nodes[2], nodes[1], "foobar") - await waitSub(nodes[1], nodes[2], "foobar") + await waitSubGraph(nodes, "foobar") var gossip1: GossipSub = GossipSub(nodes[0]) var gossip2: GossipSub = GossipSub(nodes[1]) @@ -587,7 +576,11 @@ suite "GossipSub": nodes[1].addValidator("foobar", slowValidator) - tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 + checkExpiring( + gossip1.mesh.getOrDefault("foobar").len == 2 and + gossip2.mesh.getOrDefault("foobar").len == 2 and + gossip3.mesh.getOrDefault("foobar").len == 2) + tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 2 await bFinished @@ -629,7 +622,7 @@ suite "GossipSub": tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 - check await passed + check await passed.wait(10.seconds) check: "foobar" in gossip1.gossipsub diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 095c68c45..6ac49b9b8 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -132,13 +132,17 @@ proc waitSubGraph*(nodes: seq[PubSub], key: string) {.async, gcsafe.} = seen: HashSet[PeerId] for n in nodes: nodesMesh[n.peerInfo.peerId] = toSeq(GossipSub(n).mesh.getOrDefault(key).items()).mapIt(it.peerId) - proc explore(p: PeerId) = - if p in seen: return - seen.incl(p) - for peer in nodesMesh.getOrDefault(p): - explore(peer) - explore(nodes[0].peerInfo.peerId) - if seen.len == nodes.len: return + var ok = 0 + for n in nodes: + seen.clear() + proc explore(p: PeerId) = + if p in seen: return + seen.incl(p) + for peer in nodesMesh.getOrDefault(p): + explore(peer) + explore(n.peerInfo.peerId) + if seen.len == nodes.len: ok.inc() + if ok == nodes.len: return trace "waitSubGraph sleeping..." await sleepAsync(5.milliseconds) diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index f15215766..72f2f403c 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -10,8 +10,8 @@ import ../libp2p/[connmanager, import helpers -proc getConnection(peerId: PeerId, dir: Direction = Direction.In): Connection = - return Connection.new(peerId, dir, Opt.none(MultiAddress)) +proc getMuxer(peerId: PeerId, dir: Direction = Direction.In): Muxer = + return Muxer(connection: Connection.new(peerId, dir, Opt.none(MultiAddress))) type TestMuxer = ref object of Muxer @@ -22,71 +22,55 @@ method newStream*( name: string = "", lazy: bool = false): Future[Connection] {.async, gcsafe.} = - result = getConnection(m.peerId, Direction.Out) + result = Connection.new(m.peerId, Direction.Out, Opt.none(MultiAddress)) suite "Connection Manager": teardown: checkTrackers() - asyncTest "add and retrieve a connection": + asyncTest "add and retrieve a muxer": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) + let mux = getMuxer(peerId) - connMngr.storeConn(conn) - check conn in connMngr + connMngr.storeMuxer(mux) + check mux in connMngr - let peerConn = connMngr.selectConn(peerId) - check peerConn == conn - check peerConn.dir == Direction.In + let peerMux = connMngr.selectMuxer(peerId) + check peerMux == mux + check peerMux.connection.dir == Direction.In await connMngr.close() asyncTest "shouldn't allow a closed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) - await conn.close() + let mux = getMuxer(peerId) + await mux.connection.close() expect CatchableError: - connMngr.storeConn(conn) + connMngr.storeMuxer(mux) await connMngr.close() asyncTest "shouldn't allow an EOFed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) - conn.isEof = true + let mux = getMuxer(peerId) + mux.connection.isEof = true expect CatchableError: - connMngr.storeConn(conn) + connMngr.storeMuxer(mux) - await conn.close() + await mux.close() await connMngr.close() - asyncTest "add and retrieve a muxer": + asyncTest "shouldn't allow a muxer with no connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) - let muxer = new Muxer - muxer.connection = conn - - connMngr.storeConn(conn) - connMngr.storeMuxer(muxer) - check muxer in connMngr - - let peerMuxer = connMngr.selectMuxer(conn) - check peerMuxer == muxer - - await connMngr.close() - - asyncTest "shouldn't allow a muxer for an untracked connection": - let connMngr = ConnManager.new() - let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) - let muxer = new Muxer - muxer.connection = conn + let muxer = getMuxer(peerId) + let conn = muxer.connection + muxer.connection = nil expect CatchableError: connMngr.storeMuxer(muxer) @@ -99,33 +83,34 @@ suite "Connection Manager": # This would work with 1 as well cause of a bug in connmanager that will get fixed soon let connMngr = ConnManager.new(maxConnsPerPeer = 2) let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn1 = getConnection(peerId, Direction.Out) - let conn2 = getConnection(peerId) + let mux1 = getMuxer(peerId, Direction.Out) + let mux2 = getMuxer(peerId) - connMngr.storeConn(conn1) - connMngr.storeConn(conn2) - check conn1 in connMngr - check conn2 in connMngr + connMngr.storeMuxer(mux1) + connMngr.storeMuxer(mux2) + check mux1 in connMngr + check mux2 in connMngr - let outConn = connMngr.selectConn(peerId, Direction.Out) - let inConn = connMngr.selectConn(peerId, Direction.In) + let outMux = connMngr.selectMuxer(peerId, Direction.Out) + let inMux = connMngr.selectMuxer(peerId, Direction.In) - check outConn != inConn - check outConn.dir == Direction.Out - check inConn.dir == Direction.In + check outMux != inMux + check outMux == mux1 + check inMux == mux2 + check outMux.connection.dir == Direction.Out + check inMux.connection.dir == Direction.In await connMngr.close() asyncTest "get muxed stream for peer": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) let muxer = new TestMuxer + let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress)) muxer.peerId = peerId - muxer.connection = conn + muxer.connection = connection - connMngr.storeConn(conn) connMngr.storeMuxer(muxer) check muxer in connMngr @@ -134,18 +119,18 @@ suite "Connection Manager": check stream.peerId == peerId await connMngr.close() + await connection.close() await stream.close() asyncTest "get stream from directed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) let muxer = new TestMuxer + let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress)) muxer.peerId = peerId - muxer.connection = conn + muxer.connection = connection - connMngr.storeConn(conn) connMngr.storeMuxer(muxer) check muxer in connMngr @@ -156,57 +141,37 @@ suite "Connection Manager": await connMngr.close() await stream1.close() - - asyncTest "get stream from any connection": - let connMngr = ConnManager.new() - let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) - - let muxer = new TestMuxer - muxer.peerId = peerId - muxer.connection = conn - - connMngr.storeConn(conn) - connMngr.storeMuxer(muxer) - check muxer in connMngr - - let stream = await connMngr.getStream(conn) - check not(isNil(stream)) - - await connMngr.close() - await stream.close() + await connection.close() asyncTest "should raise on too many connections": let connMngr = ConnManager.new(maxConnsPerPeer = 0) let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - connMngr.storeConn(getConnection(peerId)) + connMngr.storeMuxer(getMuxer(peerId)) - let conns = @[ - getConnection(peerId), - getConnection(peerId)] + let muxs = @[getMuxer(peerId)] expect TooManyConnectionsError: - connMngr.storeConn(conns[0]) + connMngr.storeMuxer(muxs[0]) await connMngr.close() await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) + allFutures(muxs.mapIt( it.close() ))) asyncTest "expect connection from peer": # FIXME This should be 1 instead of 0, it will get fixed soon let connMngr = ConnManager.new(maxConnsPerPeer = 0) let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - connMngr.storeConn(getConnection(peerId)) + connMngr.storeMuxer(getMuxer(peerId)) - let conns = @[ - getConnection(peerId), - getConnection(peerId)] + let muxs = @[ + getMuxer(peerId), + getMuxer(peerId)] expect TooManyConnectionsError: - connMngr.storeConn(conns[0]) + connMngr.storeMuxer(muxs[0]) let waitedConn1 = connMngr.expectConnection(peerId, In) @@ -217,38 +182,32 @@ suite "Connection Manager": let waitedConn2 = connMngr.expectConnection(peerId, In) waitedConn3 = connMngr.expectConnection(PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), In) - conn = getConnection(peerId) - connMngr.storeConn(conn) + conn = getMuxer(peerId) + connMngr.storeMuxer(conn) check (await waitedConn2) == conn expect TooManyConnectionsError: - connMngr.storeConn(conns[1]) + connMngr.storeMuxer(muxs[1]) await connMngr.close() checkExpiring: waitedConn3.cancelled() await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) + allFutures(muxs.mapIt( it.close() ))) asyncTest "cleanup on connection close": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = getConnection(peerId) - let muxer = new Muxer + let muxer = getMuxer(peerId) - muxer.connection = conn - connMngr.storeConn(conn) connMngr.storeMuxer(muxer) - check conn in connMngr check muxer in connMngr - await conn.close() - await sleepAsync(10.millis) + await muxer.close() - check conn notin connMngr - check muxer notin connMngr + checkExpiring: muxer notin connMngr await connMngr.close() @@ -261,23 +220,19 @@ suite "Connection Manager": Direction.In else: Direction.Out - let conn = getConnection(peerId, dir) - let muxer = new Muxer - muxer.connection = conn + let muxer = getMuxer(peerId, dir) - connMngr.storeConn(conn) connMngr.storeMuxer(muxer) - check conn in connMngr check muxer in connMngr - check not(isNil(connMngr.selectConn(peerId, dir))) + check not(isNil(connMngr.selectMuxer(peerId, dir))) check peerId in connMngr await connMngr.dropPeer(peerId) - check peerId notin connMngr - check isNil(connMngr.selectConn(peerId, Direction.In)) - check isNil(connMngr.selectConn(peerId, Direction.Out)) + checkExpiring: peerId notin connMngr + check isNil(connMngr.selectMuxer(peerId, Direction.In)) + check isNil(connMngr.selectMuxer(peerId, Direction.Out)) await connMngr.close() @@ -363,7 +318,6 @@ suite "Connection Manager": asyncTest "track incoming max connections limits - fail on outgoing": let connMngr = ConnManager.new(maxIn = 3) - var conns: seq[Connection] for i in 0..<3: check await connMngr.getIncomingSlot().withTimeout(10.millis) @@ -376,7 +330,6 @@ suite "Connection Manager": asyncTest "allow force dial": let connMngr = ConnManager.new(maxConnections = 2) - var conns: seq[Connection] for i in 0..<3: discard connMngr.getOutgoingSlot(true) @@ -389,17 +342,17 @@ suite "Connection Manager": asyncTest "release slot on connection end": let connMngr = ConnManager.new(maxConnections = 3) - var conns: seq[Connection] + var muxs: seq[Muxer] for i in 0..<3: let slot = connMngr.getOutgoingSlot() - let conn = - getConnection( + let muxer = + getMuxer( PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), Direction.In) - slot.trackConnection(conn) - conns.add(conn) + slot.trackMuxer(muxer) + muxs.add(muxer) # should be full now let incomingSlot = connMngr.getIncomingSlot() @@ -407,7 +360,7 @@ suite "Connection Manager": check (await incomingSlot.withTimeout(10.millis)) == false await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) + allFutures(muxs.mapIt( it.close() ))) check await incomingSlot.withTimeout(10.millis) diff --git a/tests/testidentify.nim b/tests/testidentify.nim index 9fa1f4ddd..8b43237f1 100644 --- a/tests/testidentify.nim +++ b/tests/testidentify.nim @@ -177,7 +177,7 @@ suite "Identify": check: switch1.peerStore[AddressBook][switch2.peerInfo.peerId] == switch2.peerInfo.addrs switch2.peerStore[AddressBook][switch1.peerInfo.peerId] == switch1.peerInfo.addrs - + switch1.peerStore[KeyBook][switch2.peerInfo.peerId] == switch2.peerInfo.publicKey switch2.peerStore[KeyBook][switch1.peerInfo.peerId] == switch1.peerInfo.publicKey diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index a29993d53..1f10a9255 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -224,8 +224,7 @@ suite "Multistream select": var conn: Connection = nil proc testNaHandler(msg: string): Future[void] {.async, gcsafe.} = - echo msg - check msg == Na + check msg == "\x03na\n" await conn.close() conn = newTestNaStream(testNaHandler) diff --git a/tests/testnoise.nim b/tests/testnoise.nim index ce0b51b00..1019eb1d7 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -67,6 +67,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool, secio: bool = false): (Switc let identify = Identify.new(peerInfo) + peerStore = PeerStore.new(identify) mplexProvider = MuxerProvider.new(createMplex, MplexCodec) muxers = @[mplexProvider] secureManagers = if secio: @@ -75,16 +76,16 @@ proc createSwitch(ma: MultiAddress; outgoing: bool, secio: bool = false): (Switc [Secure(Noise.new(rng, privateKey, outgoing = outgoing))] connManager = ConnManager.new() ms = MultistreamSelect.new() - muxedUpgrade = MuxedUpgrade.new(identify, muxers, secureManagers, connManager, ms) + muxedUpgrade = MuxedUpgrade.new(muxers, secureManagers, connManager, ms) transports = @[Transport(TcpTransport.new(upgrade = muxedUpgrade))] let switch = newSwitch( peerInfo, transports, - identify, secureManagers, connManager, - ms) + ms, + peerStore) result = (switch, peerInfo) suite "Noise": diff --git a/tests/testpeerstore.nim b/tests/testpeerstore.nim index 74477319b..b6ce7cdd5 100644 --- a/tests/testpeerstore.nim +++ b/tests/testpeerstore.nim @@ -96,7 +96,7 @@ suite "PeerStore": toSeq(values(addressBook.book))[0] == @[multiaddr1, multiaddr2] test "Pruner - no capacity": - let peerStore = PeerStore.new(capacity = 0) + let peerStore = PeerStore.new(nil, capacity = 0) peerStore[AgentBook][peerId1] = "gds" peerStore.cleanup(peerId1) @@ -104,7 +104,7 @@ suite "PeerStore": check peerId1 notin peerStore[AgentBook] test "Pruner - FIFO": - let peerStore = PeerStore.new(capacity = 1) + let peerStore = PeerStore.new(nil, capacity = 1) peerStore[AgentBook][peerId1] = "gds" peerStore[AgentBook][peerId2] = "gds" peerStore.cleanup(peerId2) @@ -114,7 +114,7 @@ suite "PeerStore": peerId2 notin peerStore[AgentBook] test "Pruner - regular capacity": - var peerStore = PeerStore.new(capacity = 20) + var peerStore = PeerStore.new(nil, capacity = 20) for i in 0..<30: let randomPeerId = PeerId.init(KeyPair.random(ECDSA, rng[]).get().pubkey).get() @@ -124,7 +124,7 @@ suite "PeerStore": check peerStore[AgentBook].len == 20 test "Pruner - infinite capacity": - var peerStore = PeerStore.new(capacity = -1) + var peerStore = PeerStore.new(nil, capacity = -1) for i in 0..<30: let randomPeerId = PeerId.init(KeyPair.random(ECDSA, rng[]).get().pubkey).get()