diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index 3070123d7..15cdad9b2 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -9,7 +9,7 @@ {.push raises: [Defect].} -import std/[strutils] +import std/[strutils, sequtils] import chronos, chronicles, stew/byteutils import stream/connection, protocols/protocol @@ -209,3 +209,9 @@ proc addHandler*(m: MultistreamSelect, m.handlers.add(HandlerHolder(protos: @[codec], protocol: protocol, match: matcher)) + +proc start*(m: MultistreamSelect) {.async.} = + await allFutures(m.handlers.mapIt(it.protocol.start())) + +proc stop*(m: MultistreamSelect) {.async.} = + await allFutures(m.handlers.mapIt(it.protocol.stop())) diff --git a/libp2p/protocols/protocol.nim b/libp2p/protocols/protocol.nim index aa734f46a..edba47647 100644 --- a/libp2p/protocols/protocol.nim +++ b/libp2p/protocols/protocol.nim @@ -22,8 +22,12 @@ type LPProtocol* = ref object of RootObj codecs*: seq[string] handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator + started*: bool method init*(p: LPProtocol) {.base, gcsafe.} = discard +method start*(p: LPProtocol) {.async, base.} = p.started = true +method stop*(p: LPProtocol) {.async, base.} = p.started = false + func codec*(p: LPProtocol): string = assert(p.codecs.len > 0, "Codecs sequence was empty!") diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 86f3df419..931cd57ba 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -562,22 +562,28 @@ method publish*(g: GossipSub, return peers.len +proc maintainDirectPeer(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} = + let peer = g.peers.getOrDefault(id) + if isNil(peer): + trace "Attempting to dial a direct peer", peer = id + try: + await g.switch.connect(id, addrs) + # populate the peer after it's connected + discard g.getOrCreatePeer(id, g.codecs) + except CancelledError as exc: + trace "Direct peer dial canceled" + raise exc + except CatchableError as exc: + debug "Direct peer error dialing", msg = exc.msg + +proc addDirectPeer*(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} = + g.parameters.directPeers[id] = addrs + await g.maintainDirectPeer(id, addrs) + proc maintainDirectPeers(g: GossipSub) {.async.} = heartbeat "GossipSub DirectPeers", 1.minutes: for id, addrs in g.parameters.directPeers: - let peer = g.peers.getOrDefault(id) - if isNil(peer): - trace "Attempting to dial a direct peer", peer = id - try: - # dial, internally connection will be stored - let _ = await g.switch.dial(id, addrs, g.codecs) - # populate the peer after it's connected - discard g.getOrCreatePeer(id, g.codecs) - except CancelledError as exc: - trace "Direct peer dial canceled" - raise exc - except CatchableError as exc: - debug "Direct peer error dialing", msg = exc.msg + await g.addDirectPeer(id, addrs) method start*(g: GossipSub) {.async.} = trace "gossipsub start" @@ -589,9 +595,11 @@ method start*(g: GossipSub) {.async.} = g.heartbeatFut = g.heartbeat() g.scoringHeartbeatFut = g.scoringHeartbeat() g.directPeersLoop = g.maintainDirectPeers() + g.started = true method stop*(g: GossipSub) {.async.} = trace "gossipsub stop" + g.started = false if g.heartbeatFut.isNil: warn "Stopping gossipsub without starting it" return diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 1d6819ef5..8ddea7389 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -488,14 +488,6 @@ method initPubSub*(p: PubSub) if p.msgIdProvider == nil: p.msgIdProvider = defaultMsgIdProvider -method start*(p: PubSub) {.async, base.} = - ## start pubsub - discard - -method stop*(p: PubSub) {.async, base.} = - ## stopt pubsub - discard - method addValidator*(p: PubSub, topic: varargs[string], hook: ValidatorHandler) {.base.} = diff --git a/libp2p/switch.nim b/libp2p/switch.nim index c8957703d..256a6ded1 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -65,6 +65,7 @@ type dialer*: Dial peerStore*: PeerStore nameResolver*: NameResolver + started: bool proc addConnEventHandler*(s: Switch, handler: ConnEventHandler, @@ -144,6 +145,9 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) raise newException(LPError, "Protocol has to define a codec string") + if s.started and not proto.started: + raise newException(LPError, "Protocol not started") + s.ms.addHandler(proto.codecs, proto, matcher) s.peerInfo.protocols.add(proto.codec) @@ -216,6 +220,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises proc stop*(s: Switch) {.async.} = trace "Stopping switch" + s.started = false # close and cleanup all connections await s.connManager.close() @@ -239,6 +244,8 @@ proc stop*(s: Switch) {.async.} = if not a.finished: a.cancel() + await s.ms.stop() + trace "Switch stopped" proc start*(s: Switch) {.async, gcsafe.} = @@ -272,6 +279,10 @@ proc start*(s: Switch) {.async, gcsafe.} = s.peerInfo.update() + await s.ms.start() + + s.started = true + debug "Started libp2p node", peer = s.peerInfo proc newSwitch*(peerInfo: PeerInfo, diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 38d00d9cd..78513281c 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -54,13 +54,6 @@ suite "FloodSub": nodes[1].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[1].subscribe("foobar", handler) @@ -74,11 +67,6 @@ suite "FloodSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "FloodSub basic publish/subscribe B -> A": @@ -96,12 +84,6 @@ suite "FloodSub": nodes[1].switch.start(), ) - # start pubsubcon - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) await subscribeNodes(nodes) @@ -117,11 +99,6 @@ suite "FloodSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut) asyncTest "FloodSub validation should succeed": @@ -139,13 +116,6 @@ suite "FloodSub": nodes[1].switch.start(), ) - # start pubsubcon - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[1].subscribe("foobar", handler) @@ -168,11 +138,6 @@ suite "FloodSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut) asyncTest "FloodSub validation should fail": @@ -188,13 +153,6 @@ suite "FloodSub": nodes[1].switch.start(), ) - # start pubsubcon - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[1].subscribe("foobar", handler) await waitSub(nodes[0], nodes[1], "foobar") @@ -214,11 +172,6 @@ suite "FloodSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut) asyncTest "FloodSub validation one fails and one succeeds": @@ -236,13 +189,6 @@ suite "FloodSub": nodes[1].switch.start(), ) - # start pubsubcon - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[1].subscribe("foo", handler) await waitSub(nodes[0], nodes[1], "foo") @@ -266,11 +212,6 @@ suite "FloodSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut) asyncTest "FloodSub multiple peers, no self trigger": @@ -296,7 +237,6 @@ suite "FloodSub": nodes = generateNodes(runs, triggerSelf = false) nodesFut = nodes.mapIt(it.switch.start()) - await allFuturesThrowing(nodes.mapIt(it.start())) await subscribeNodes(nodes) for i in 0.. B": @@ -511,13 +427,6 @@ suite "GossipSub": nodes[1].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[1].subscribe("foobar", handler) @@ -549,19 +458,11 @@ suite "GossipSub": trace "test done, stopping..." - await nodes[0].stop() - await nodes[1].stop() - await allFuturesThrowing( nodes[0].switch.stop(), nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) check observed == 2 @@ -583,13 +484,6 @@ suite "GossipSub": nodes[1].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[1].subscribe("foobar", handler) @@ -616,19 +510,11 @@ suite "GossipSub": trace "test done, stopping..." - await nodes[0].stop() - await nodes[1].stop() - await allFuturesThrowing( nodes[0].switch.stop(), nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "e2e - GossipSub send over mesh A -> B": @@ -648,13 +534,6 @@ suite "GossipSub": nodes[1].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - await subscribeNodes(nodes) nodes[0].subscribe("foobar", handler) @@ -681,11 +560,6 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "e2e - GossipSub should not send to source & peers who already seen": @@ -705,14 +579,6 @@ suite "GossipSub": nodes[2].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - nodes[2].start(), - )) - await subscribeNodes(nodes) var cRelayed: Future[void] = newFuture[void]() @@ -763,12 +629,6 @@ suite "GossipSub": nodes[2].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop(), - nodes[2].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "e2e - GossipSub send over floodPublish A -> B": @@ -788,13 +648,6 @@ suite "GossipSub": nodes[1].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - var gossip1: GossipSub = GossipSub(nodes[0]) gossip1.parameters.floodPublish = true var gossip2: GossipSub = GossipSub(nodes[1]) @@ -821,11 +674,6 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "e2e - GossipSub with multiple peers": @@ -835,7 +683,6 @@ suite "GossipSub": nodes = generateNodes(runs, gossip = true, triggerSelf = true) nodesFut = nodes.mapIt(it.switch.start()) - await allFuturesThrowing(nodes.mapIt(it.start())) await subscribeNodes(nodes) var seen: Table[string, int] @@ -875,7 +722,6 @@ suite "GossipSub": await allFuturesThrowing( nodes.mapIt( allFutures( - it.stop(), it.switch.stop()))) await allFuturesThrowing(nodesFut) @@ -887,7 +733,6 @@ suite "GossipSub": nodes = generateNodes(runs, gossip = true, triggerSelf = true) nodesFut = nodes.mapIt(it.switch.start()) - await allFuturesThrowing(nodes.mapIt(it.start())) await subscribeSparseNodes(nodes) var seen: Table[string, int] @@ -928,7 +773,6 @@ suite "GossipSub": await allFuturesThrowing( nodes.mapIt( allFutures( - it.stop(), it.switch.stop()))) await allFuturesThrowing(nodesFut) @@ -956,14 +800,6 @@ suite "GossipSub": nodes[2].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - nodes[2].start(), - )) - var gossip0 = GossipSub(nodes[0]) gossip1 = GossipSub(nodes[1]) @@ -997,11 +833,4 @@ suite "GossipSub": nodes[2].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop(), - nodes[2].stop() - ) - await allFuturesThrowing(nodesFut.concat()) - diff --git a/tests/pubsub/testgossipsub2.nim b/tests/pubsub/testgossipsub2.nim index 70d44b865..fa19447f9 100644 --- a/tests/pubsub/testgossipsub2.nim +++ b/tests/pubsub/testgossipsub2.nim @@ -77,7 +77,6 @@ suite "GossipSub": nodes = generateNodes(runs, gossip = true, triggerSelf = true) nodesFut = nodes.mapIt(it.switch.start()) - await allFuturesThrowing(nodes.mapIt(it.start())) await subscribeSparseNodes(nodes) var seen: Table[string, int] @@ -120,7 +119,6 @@ suite "GossipSub": await allFuturesThrowing( nodes.mapIt( allFutures( - it.stop(), it.switch.stop()))) await allFuturesThrowing(nodesFut) @@ -140,13 +138,6 @@ suite "GossipSub": nodes[1].switch.start(), ) - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - # We must subscribe before setting the validator nodes[0].subscribe("foobar", handler) @@ -174,36 +165,15 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "GossipSub test directPeers": - - let - nodes = generateNodes(2, gossip = true) - - # start switches - nodesFut = await allFinished( - nodes[0].switch.start(), - nodes[1].switch.start(), - ) - - var gossip = GossipSub(nodes[0]) - gossip.parameters.directPeers[nodes[1].switch.peerInfo.peerId] = nodes[1].switch.peerInfo.addrs - - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) + let nodes = generateNodes(2, gossip = true) + await allFutures(nodes[0].switch.start(), nodes[1].switch.start()) + await GossipSub(nodes[0]).addDirectPeer(nodes[1].switch.peerInfo.peerId, nodes[1].switch.peerInfo.addrs) let invalidDetected = newFuture[void]() - gossip.subscriptionValidator = + GossipSub(nodes[0]).subscriptionValidator = proc(topic: string): bool = if topic == "foobar": try: @@ -227,13 +197,6 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - - await allFuturesThrowing(nodesFut.concat()) - asyncTest "GossipSub directPeers: always forward messages": let nodes = generateNodes(2, gossip = true) @@ -244,15 +207,8 @@ suite "GossipSub": nodes[1].switch.start(), ) - GossipSub(nodes[0]).parameters.directPeers[nodes[1].switch.peerInfo.peerId] = nodes[1].switch.peerInfo.addrs - GossipSub(nodes[1]).parameters.directPeers[nodes[0].switch.peerInfo.peerId] = nodes[0].switch.peerInfo.addrs - - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) + await GossipSub(nodes[0]).addDirectPeer(nodes[1].switch.peerInfo.peerId, nodes[1].switch.peerInfo.addrs) + await GossipSub(nodes[1]).addDirectPeer(nodes[0].switch.peerInfo.peerId, nodes[0].switch.peerInfo.addrs) var handlerFut = newFuture[void]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = @@ -275,11 +231,6 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "GossipSub directPeers: don't kick direct peer with low score": @@ -292,19 +243,12 @@ suite "GossipSub": nodes[1].switch.start(), ) - GossipSub(nodes[0]).parameters.directPeers[nodes[1].switch.peerInfo.peerId] = nodes[1].switch.peerInfo.addrs - GossipSub(nodes[1]).parameters.directPeers[nodes[0].switch.peerInfo.peerId] = nodes[0].switch.peerInfo.addrs + await GossipSub(nodes[0]).addDirectPeer(nodes[1].switch.peerInfo.peerId, nodes[1].switch.peerInfo.addrs) + await GossipSub(nodes[1]).addDirectPeer(nodes[0].switch.peerInfo.peerId, nodes[0].switch.peerInfo.addrs) GossipSub(nodes[1]).parameters.disconnectBadPeers = true GossipSub(nodes[1]).parameters.graylistThreshold = 100000 - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) - var handlerFut = newFuture[void]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -334,11 +278,6 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) asyncTest "GossipsSub peers disconnections mechanics": @@ -348,7 +287,6 @@ suite "GossipSub": nodes = generateNodes(runs, gossip = true, triggerSelf = true) nodesFut = nodes.mapIt(it.switch.start()) - await allFuturesThrowing(nodes.mapIt(it.start())) await subscribeNodes(nodes) var seen: Table[string, int] @@ -434,7 +372,6 @@ suite "GossipSub": await allFuturesThrowing( nodes.mapIt( allFutures( - it.stop(), it.switch.stop()))) await allFuturesThrowing(nodesFut) @@ -444,24 +381,18 @@ suite "GossipSub": let nodes = generateNodes(2, gossip = true) - # start switches - nodesFut = await allFinished( - nodes[0].switch.start(), - nodes[1].switch.start(), - ) - var gossip = GossipSub(nodes[0]) # MacOs has some nasty jitter when sleeping # (up to 7 ms), so we need some pretty long # sleeps to be safe here gossip.parameters.decayInterval = 300.milliseconds - # start pubsub - await allFuturesThrowing( - allFinished( - nodes[0].start(), - nodes[1].start(), - )) + let + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) var handlerFut = newFuture[void]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = @@ -489,9 +420,4 @@ suite "GossipSub": nodes[1].switch.stop() ) - await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop() - ) - await allFuturesThrowing(nodesFut.concat()) diff --git a/tests/testinterop.nim b/tests/testinterop.nim index cecace6d0..ce05efc95 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -55,7 +55,6 @@ proc testPubSubDaemonPublish(gossip: bool = false, count: int = 1) {.async.} = nativeNode.mount(pubsub) await nativeNode.start() - await pubsub.start() let nativePeer = nativeNode.peerInfo var finished = false @@ -89,7 +88,6 @@ proc testPubSubDaemonPublish(gossip: bool = false, count: int = 1) {.async.} = await wait(publisher(), 5.minutes) # should be plenty of time await nativeNode.stop() - await pubsub.stop() await daemonNode.close() proc testPubSubNodePublish(gossip: bool = false, count: int = 1) {.async.} = @@ -115,7 +113,6 @@ proc testPubSubNodePublish(gossip: bool = false, count: int = 1) {.async.} = nativeNode.mount(pubsub) await nativeNode.start() - await pubsub.start() let nativePeer = nativeNode.peerInfo await nativeNode.connect(daemonPeer.peer, daemonPeer.addresses) @@ -149,7 +146,6 @@ proc testPubSubNodePublish(gossip: bool = false, count: int = 1) {.async.} = check finished await nativeNode.stop() - await pubsub.stop() await daemonNode.close() suite "Interop": diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 23cb6ed0b..fc64a21f7 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -979,3 +979,29 @@ suite "Switch": await destSwitch.stop() await srcWsSwitch.stop() await srcTcpSwitch.stop() + + asyncTest "mount unstarted protocol": + proc handle(conn: Connection, proto: string) {.async, gcsafe.} = + check "test123" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test456") + await conn.close() + let + src = newStandardSwitch() + dst = newStandardSwitch() + testProto = new TestProto + testProto.codec = TestCodec + testProto.handler = handle + + await src.start() + await dst.start() + expect LPError: + dst.mount(testProto) + await testProto.start() + dst.mount(testProto) + + let conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, TestCodec) + await conn.writeLp("test123") + check "test456" == string.fromBytes(await conn.readLp(1024)) + await conn.close() + await src.stop() + await dst.stop()