diff --git a/libp2p/switch.nim b/libp2p/switch.nim index dc62f6a..eca9691 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -28,7 +28,8 @@ import stream/connection, protocols/pubsub/pubsub, muxers/muxer, connmanager, - peerid + peerid, + errors logScope: topics = "switch" @@ -46,6 +47,13 @@ declareCounter(libp2p_failed_upgrade, "peers failed upgrade") type NoPubSubException* = object of CatchableError + Lifecycle* {.pure.} = enum + Connected, + Upgraded, + Disconnected + + Hook* = proc(peer: PeerInfo, cycle: Lifecycle): Future[void] {.gcsafe.} + Switch* = ref object of RootObj peerInfo*: PeerInfo connManager: ConnManager @@ -58,10 +66,31 @@ type secureManagers*: seq[Secure] pubSub*: Option[PubSub] dialLock: Table[string, AsyncLock] + hooks: Table[Lifecycle, HashSet[Hook]] proc newNoPubSubException(): ref NoPubSubException {.inline.} = result = newException(NoPubSubException, "no pubsub provided!") +proc addHook*(s: Switch, hook: Hook, cycle: Lifecycle) = + s.hooks.mgetOrPut(cycle, initHashSet[Hook]()).incl(hook) + +proc removeHook*(s: Switch, hook: Hook, cycle: Lifecycle) = + s.hooks.mgetOrPut(cycle, initHashSet[Hook]()).excl(hook) + +proc triggerHooks(s: Switch, peer: PeerInfo, cycle: Lifecycle) {.async, gcsafe.} = + try: + if cycle in s.hooks: + var hooks: seq[Future[void]] + for h in s.hooks[cycle]: + if not(isNil(h)): + hooks.add(h(peer, cycle)) + + checkFutures(await allFinished(hooks)) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "exception in trigger hooks", exc = exc.msg + proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} @@ -187,9 +216,10 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g if isNil(sconn.peerInfo): await sconn.close() raise newException(CatchableError, - "unable to mux connection, stopping upgrade") + "unable to identify connection, stopping upgrade") + + trace "succesfully upgraded outgoing connection", oid = sconn.oid - trace "succesfully upgraded outgoing connection", uoid = sconn.oid return sconn proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = @@ -258,6 +288,13 @@ proc internalConnect(s: Switch, # make sure to assign the peer to the connection conn.peerInfo = peer + conn.closeEvent.wait() + .addCallback do(udata: pointer): + asyncCheck s.triggerHooks( + conn.peerInfo, + Lifecycle.Disconnected) + + asyncCheck s.triggerHooks(conn.peerInfo, Lifecycle.Connected) libp2p_dialed_peers.inc() except CancelledError as exc: trace "dialing canceled", exc = exc.msg @@ -270,7 +307,9 @@ proc internalConnect(s: Switch, try: let uconn = await s.upgradeOutgoing(conn) s.connManager.storeOutgoing(uconn) + asyncCheck s.triggerHooks(uconn.peerInfo, Lifecycle.Upgraded) conn = uconn + trace "dial succesfull", oid = $conn.oid, peer = $conn.peerInfo except CatchableError as exc: if not(isNil(conn)): await conn.close() @@ -283,7 +322,7 @@ proc internalConnect(s: Switch, continue break else: - trace "Reusing existing connection", oid = conn.oid + trace "Reusing existing connection", oid = $conn.oid, direction = conn.dir finally: if lock.locked(): lock.release() @@ -360,7 +399,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: - conn.dir = Direction.In # tag connection with direction + + conn.closeEvent.wait() + .addCallback do(udata: pointer): + asyncCheck s.triggerHooks( + conn.peerInfo, + Lifecycle.Disconnected) + + asyncCheck s.triggerHooks(conn.peerInfo, Lifecycle.Connected) await s.upgradeIncoming(conn) # perform upgrade on incoming connection except CancelledError as exc: raise exc @@ -437,6 +483,8 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = proc subscribe*(s: Switch, topic: string, handler: TopicHandler) {.async.} = ## subscribe to a pubsub topic + ## + if s.pubSub.isNone: raise newNoPubSubException() @@ -444,6 +492,8 @@ proc subscribe*(s: Switch, topic: string, proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} = ## unsubscribe from topics + ## + if s.pubSub.isNone: raise newNoPubSubException() @@ -457,7 +507,9 @@ proc unsubscribeAll*(s: Switch, topic: string) {.async.} = await s.pubSub.get().unsubscribeAll(topic) proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} = - # pubslish to pubsub topic + ## pubslish to pubsub topic + ## + if s.pubSub.isNone: raise newNoPubSubException() @@ -466,7 +518,9 @@ proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} proc addValidator*(s: Switch, topics: varargs[string], hook: ValidatorHandler) = - # add validator + ## add validator + ## + if s.pubSub.isNone: raise newNoPubSubException() @@ -475,7 +529,9 @@ proc addValidator*(s: Switch, proc removeValidator*(s: Switch, topics: varargs[string], hook: ValidatorHandler) = - # pubslish to pubsub topic + ## pubslish to pubsub topic + ## + if s.pubSub.isNone: raise newNoPubSubException() @@ -504,6 +560,7 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = s.connManager.storeMuxer(muxer) trace "got new muxer", peer = $muxer.connection.peerInfo + asyncCheck s.triggerHooks(muxer.connection.peerInfo, Lifecycle.Upgraded) # try establishing a pubsub connection await s.subscribePeer(muxer.connection.peerInfo) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index a237ffd..89fe5c4 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -228,3 +228,81 @@ suite "Switch": await allFuturesThrowing(awaiters) waitFor(testSwitch()) + + test "e2e should trigger hooks": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + + var step = 0 + var cycles: set[Lifecycle] + proc hook(peer: PeerInfo, cycle: Lifecycle) {.async, gcsafe.} = + cycles = cycles + {cycle} + case step: + of 0: + check cycle == Lifecycle.Connected + check if not(isNil(peer)): + peer.peerId == switch2.peerInfo.peerId + else: + true + of 1: + assert(isNil(peer) == false) + check: + cycle == Lifecycle.Upgraded + peer.peerId == switch2.peerInfo.peerId + of 2: + check: + cycle == Lifecycle.Disconnected + + check if not(isNil(peer)): + peer.peerId == switch2.peerInfo.peerId + else: + true + else: + echo "unkown cycle! ", $cycle + check false + + step.inc() + + switch1.addHook(hook, Lifecycle.Connected) + switch1.addHook(hook, Lifecycle.Upgraded) + switch1.addHook(hook, Lifecycle.Disconnected) + + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + + await switch2.connect(switch1.peerInfo) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + + await sleepAsync(100.millis) + await switch2.disconnect(switch1.peerInfo) + await sleepAsync(2.seconds) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + + var bufferTracker = getTracker(BufferStreamTrackerName) + # echo bufferTracker.dump() + check bufferTracker.isLeaked() == false + + var connTracker = getTracker(ConnectionTrackerName) + # echo connTracker.dump() + check connTracker.isLeaked() == false + + check: + cycles == { + Lifecycle.Connected, + Lifecycle.Upgraded, + Lifecycle.Disconnected + } + + await allFuturesThrowing( + switch1.stop(), + switch2.stop()) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch())