From b0d86b95dd98ce55fc0cf6f9f6ee1f64f5e399ee Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Tue, 15 Sep 2020 14:19:22 -0600 Subject: [PATCH] add peer lifecycle events (#357) * add peer lifecycle events * rework peer events to not use connection events * don't use result in pubsub and switch init * wip * use ordered hashes and remove logscope * logging * add missing test * small fixes --- libp2p/connmanager.nim | 9 +- libp2p/protocols/pubsub/pubsub.nim | 31 +++- libp2p/switch.nim | 97 ++++++++-- tests/pubsub/utils.nim | 3 - tests/testinterop.nim | 2 - tests/testswitch.nim | 277 ++++++++++++++++++++++++++++- 6 files changed, 384 insertions(+), 35 deletions(-) diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 1118a4c62..fe09df220 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -30,7 +30,7 @@ type ConnManager* = ref object of RootObj # NOTE: don't change to PeerInfo here # the reference semantics on the PeerInfo - # object itself make it succeptible to + # object itself make it susceptible to # copies and mangling by unrelated code. conns: Table[PeerID, HashSet[Connection]] muxed: Table[Connection, MuxerHolder] @@ -78,6 +78,9 @@ proc contains*(c: ConnManager, muxer: Muxer): bool = return muxer == c.muxed[conn].muxer +proc connCount*(c: ConnManager, peerId: PeerID): int = + c.conns.getOrDefault(peerId).len + proc closeMuxerHolder(muxerHolder: MuxerHolder) {.async.} = trace "Cleaning up muxer", m = muxerHolder.muxer @@ -267,6 +270,7 @@ proc dropPeer*(c: ConnManager, peerId: PeerID) {.async.} = trace "Dropping peer", peerId let conns = c.conns.getOrDefault(peerId) for conn in conns: + trace "Removing connection", conn delConn(c, conn) var muxers: seq[MuxerHolder] @@ -280,8 +284,7 @@ proc dropPeer*(c: ConnManager, peerId: PeerID) {.async.} = for conn in conns: await conn.close() - - trace "Dropped peer", peerId + trace "Dropped peer", peerId proc close*(c: ConnManager) {.async.} = ## cleanup resources for the connection diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index f89598ef1..2ef091481 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -309,15 +309,28 @@ proc init*( verifySignature: bool = true, sign: bool = true, msgIdProvider: MsgIdProvider = defaultMsgIdProvider): P = - result = P(switch: switch, - peerInfo: switch.peerInfo, - triggerSelf: triggerSelf, - verifySignature: verifySignature, - sign: sign, - peers: initTable[PeerID, PubSubPeer](), - topics: initTable[string, Topic](), - msgIdProvider: msgIdProvider) - result.initPubSub() + + let pubsub = P( + switch: switch, + peerInfo: switch.peerInfo, + triggerSelf: triggerSelf, + verifySignature: verifySignature, + sign: sign, + peers: initTable[PeerID, PubSubPeer](), + topics: initTable[string, Topic](), + msgIdProvider: msgIdProvider) + + proc peerEventHandler(peerId: PeerID, event: PeerEvent) {.async.} = + if event == PeerEvent.Joined: + pubsub.subscribePeer(peerId) + else: + pubsub.unsubscribePeer(peerId) + + switch.addPeerEventHandler(peerEventHandler, PeerEvent.Joined) + switch.addPeerEventHandler(peerEventHandler, PeerEvent.Left) + + pubsub.initPubSub() + return pubsub proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer diff --git a/libp2p/switch.nim b/libp2p/switch.nim index c8bf3c2c8..20f579afc 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -64,6 +64,13 @@ type ConnEventHandler* = proc(peerId: PeerID, event: ConnEvent): Future[void] {.gcsafe.} + PeerEvent* {.pure.} = enum + Left, + Joined + + PeerEventHandler* = + proc(peerId: PeerID, event: PeerEvent): Future[void] {.gcsafe.} + Switch* = ref object of RootObj peerInfo*: PeerInfo connManager: ConnManager @@ -75,33 +82,86 @@ type streamHandler*: StreamHandler secureManagers*: seq[Secure] dialLock: Table[PeerID, AsyncLock] - ConnEvents: Table[ConnEventKind, HashSet[ConnEventHandler]] + connEvents: Table[ConnEventKind, OrderedSet[ConnEventHandler]] + peerEvents: Table[PeerEvent, OrderedSet[PeerEventHandler]] proc addConnEventHandler*(s: Switch, handler: ConnEventHandler, kind: ConnEventKind) = ## Add peer event handler - handlers must not raise exceptions! + ## + if isNil(handler): return - s.ConnEvents.mgetOrPut(kind, initHashSet[ConnEventHandler]()).incl(handler) + s.connEvents.mgetOrPut(kind, + initOrderedSet[ConnEventHandler]()).incl(handler) proc removeConnEventHandler*(s: Switch, handler: ConnEventHandler, kind: ConnEventKind) = - s.ConnEvents.withValue(kind, handlers) do: + s.connEvents.withValue(kind, handlers) do: handlers[].excl(handler) proc triggerConnEvent(s: Switch, peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = try: - if event.kind in s.ConnEvents: - var ConnEvents: seq[Future[void]] - for h in s.ConnEvents[event.kind]: - ConnEvents.add(h(peerId, event)) + if event.kind in s.connEvents: + var connEvents: seq[Future[void]] + for h in s.connEvents[event.kind]: + connEvents.add(h(peerId, event)) - checkFutures(await allFinished(ConnEvents)) + checkFutures(await allFinished(connEvents)) except CancelledError as exc: raise exc except CatchableError as exc: # handlers should not raise! warn "Exception in triggerConnEvents", msg = exc.msg, peerId, event = $event +proc addPeerEventHandler*(s: Switch, + handler: PeerEventHandler, + kind: PeerEvent) = + ## Add peer event handler - handlers must not raise exceptions! + ## + + if isNil(handler): return + s.peerEvents.mgetOrPut(kind, + initOrderedSet[PeerEventHandler]()).incl(handler) + +proc removePeerEventHandler*(s: Switch, + handler: PeerEventHandler, + kind: PeerEvent) = + s.peerEvents.withValue(kind, handlers) do: + handlers[].excl(handler) + +proc triggerPeerEvents(s: Switch, + peerId: PeerID, + event: PeerEvent) {.async, gcsafe.} = + + if event notin s.peerEvents: + return + + try: + let count = s.connManager.connCount(peerId) + if event == PeerEvent.Joined and count != 1: + trace "peer already joined", local = s.peerInfo.peerId, + remote = peerId, event + return + elif event == PeerEvent.Left and count != 0: + trace "peer still connected or already left", local = s.peerInfo.peerId, + remote = peerId, event + return + + trace "triggering peer events", local = s.peerInfo.peerId, + remote = peerId, event + + var peerEvents: seq[Future[void]] + for h in s.peerEvents[event]: + peerEvents.add(h(peerId, event)) + + checkFutures(await allFinished(peerEvents)) + except CancelledError as exc: + raise exc + except CatchableError as exc: # handlers should not raise! + warn "exception in triggerPeerEvents", exc = exc.msg, peerId + +proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.} + proc isConnected*(s: Switch, peerId: PeerID): bool = ## returns true if the peer has one or more ## associated connections (sockets) @@ -352,6 +412,7 @@ proc internalConnect(s: Switch, # unworthy and disconnects it raise newLPStreamClosedError() + await s.triggerPeerEvents(peerId, PeerEvent.Joined) await s.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: false)) @@ -360,6 +421,7 @@ proc internalConnect(s: Switch, await conn.closeEvent.wait() await s.triggerConnEvent(peerId, ConnEvent(kind: ConnEventKind.Disconnected)) + await s.triggerPeerEvents(peerId, PeerEvent.Left) except CatchableError as exc: # This is top-level procedure which will work as separate task, so it # do not need to propogate CancelledError and should handle other errors @@ -505,9 +567,10 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = proc peerCleanup() {.async.} = try: - await muxer.connection.closeEvent.wait() + await muxer.connection.join() await s.triggerConnEvent(peerId, ConnEvent(kind: ConnEventKind.Disconnected)) + await s.triggerPeerEvents(peerId, PeerEvent.Left) except CatchableError as exc: # This is top-level procedure which will work as separate task, so it # do not need to propogate CancelledError and shouldn't leak others @@ -516,6 +579,7 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = proc peerStartup() {.async.} = try: + await s.triggerPeerEvents(peerId, PeerEvent.Joined) await s.triggerConnEvent(peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: true)) @@ -547,7 +611,7 @@ proc newSwitch*(peerInfo: PeerInfo, if secureManagers.len == 0: raise (ref CatchableError)(msg: "Provide at least one secure manager") - result = Switch( + let switch = Switch( peerInfo: peerInfo, ms: newMultistream(), transports: transports, @@ -557,11 +621,10 @@ proc newSwitch*(peerInfo: PeerInfo, secureManagers: @secureManagers, ) - let s = result # can't capture result - result.streamHandler = proc(conn: Connection) {.async, gcsafe.} = # noraises + switch.streamHandler = proc(conn: Connection) {.async, gcsafe.} = # noraises trace "Starting stream handler", conn try: - await s.ms.handle(conn) # handle incoming connection + await switch.ms.handle(conn) # handle incoming connection except CancelledError as exc: raise exc except CatchableError as exc: @@ -570,11 +633,13 @@ proc newSwitch*(peerInfo: PeerInfo, await conn.close() trace "Stream handler done", conn - result.mount(identity) + switch.mount(identity) for key, val in muxers: - val.streamHandler = result.streamHandler + val.streamHandler = switch.streamHandler val.muxerHandler = proc(muxer: Muxer): Future[void] = - s.muxerHandler(muxer) + switch.muxerHandler(muxer) + + return switch proc isConnected*(s: Switch, peerInfo: PeerInfo): bool {.deprecated: "Use PeerID version".} = diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 72b437ed6..0c0a1d7f6 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -53,7 +53,6 @@ proc subscribeNodes*(nodes: seq[PubSub]) {.async.} = for node in nodes: if dialer.switch.peerInfo.peerId != node.switch.peerInfo.peerId: await dialer.switch.connect(node.peerInfo.peerId, node.peerInfo.addrs) - dialer.subscribePeer(node.peerInfo.peerId) proc subscribeSparseNodes*(nodes: seq[PubSub], degree: int = 2) {.async.} = if nodes.len < degree: @@ -66,7 +65,6 @@ proc subscribeSparseNodes*(nodes: seq[PubSub], degree: int = 2) {.async.} = for node in nodes: if dialer.switch.peerInfo.peerId != node.peerInfo.peerId: await dialer.switch.connect(node.peerInfo.peerId, node.peerInfo.addrs) - dialer.subscribePeer(node.peerInfo.peerId) proc subscribeRandom*(nodes: seq[PubSub]) {.async.} = for dialer in nodes: @@ -76,5 +74,4 @@ proc subscribeRandom*(nodes: seq[PubSub]) {.async.} = if node.peerInfo.peerId notin dialed: if dialer.peerInfo.peerId != node.peerInfo.peerId: await dialer.switch.connect(node.peerInfo.peerId, node.peerInfo.addrs) - dialer.subscribePeer(node.peerInfo.peerId) dialed.add(node.peerInfo.peerId) diff --git a/tests/testinterop.nim b/tests/testinterop.nim index 1573a658a..c2fb9656f 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -100,7 +100,6 @@ proc testPubSubDaemonPublish(gossip: bool = false, daemonPeer.peer, daemonPeer.addresses) await nativeNode.connect(peer.peerId, peer.addrs) - pubsub.subscribePeer(peer.peerId) await sleepAsync(1.seconds) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) @@ -160,7 +159,6 @@ proc testPubSubNodePublish(gossip: bool = false, daemonPeer.peer, daemonPeer.addresses) await nativeNode.connect(peer) - pubsub.subscribePeer(peer.peerId) await sleepAsync(1.seconds) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index fe2391a64..fd1645807 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,6 +1,6 @@ {.used.} -import unittest +import unittest, options import chronos import stew/byteutils import nimcrypto/sysrand @@ -229,7 +229,72 @@ suite "Switch": waitFor(testSwitch()) - test "e2e should trigger hooks": + test "e2e should trigger connection events (remote)": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + + var step = 0 + var kinds: set[ConnEventKind] + proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + kinds = kinds + {event.kind} + case step: + of 0: + check: + event.kind == ConnEventKind.Connected + peerId == switch1.peerInfo.peerId + of 1: + check: + event.kind == ConnEventKind.Disconnected + + check peerId == switch1.peerInfo.peerId + else: + check false + + step.inc() + + switch2.addConnEventHandler(hook, ConnEventKind.Connected) + switch2.addConnEventHandler(hook, ConnEventKind.Disconnected) + + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + + await switch2.connect(switch1.peerInfo) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + + await sleepAsync(100.millis) + await switch2.disconnect(switch1.peerInfo) + await sleepAsync(2.seconds) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + + var bufferTracker = getTracker(BufferStreamTrackerName) + # echo bufferTracker.dump() + check bufferTracker.isLeaked() == false + + var connTracker = getTracker(ConnectionTrackerName) + # echo connTracker.dump() + check connTracker.isLeaked() == false + + check: + kinds == { + ConnEventKind.Connected, + ConnEventKind.Disconnected + } + + await allFuturesThrowing( + switch1.stop(), + switch2.stop()) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch()) + + test "e2e should trigger connection events (local)": proc testSwitch() {.async, gcsafe.} = var awaiters: seq[Future[void]] @@ -293,3 +358,211 @@ suite "Switch": await allFuturesThrowing(awaiters) waitFor(testSwitch()) + + test "e2e should trigger peer events (remote)": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + + var step = 0 + var kinds: set[PeerEvent] + proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = + kinds = kinds + {event} + case step: + of 0: + check: + event == PeerEvent.Joined + peerId == switch2.peerInfo.peerId + of 1: + check: + event == PeerEvent.Left + peerId == switch2.peerInfo.peerId + else: + check false + + step.inc() + + switch1.addPeerEventHandler(handler, PeerEvent.Joined) + switch1.addPeerEventHandler(handler, PeerEvent.Left) + + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + + await switch2.connect(switch1.peerInfo) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + + await sleepAsync(100.millis) + await switch2.disconnect(switch1.peerInfo) + await sleepAsync(2.seconds) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + + var bufferTracker = getTracker(BufferStreamTrackerName) + # echo bufferTracker.dump() + check bufferTracker.isLeaked() == false + + var connTracker = getTracker(ConnectionTrackerName) + # echo connTracker.dump() + check connTracker.isLeaked() == false + + check: + kinds == { + PeerEvent.Joined, + PeerEvent.Left + } + + await allFuturesThrowing( + switch1.stop(), + switch2.stop()) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch()) + + test "e2e should trigger peer events (local)": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + + var step = 0 + var kinds: set[PeerEvent] + proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = + kinds = kinds + {event} + case step: + of 0: + check: + event == PeerEvent.Joined + peerId == switch1.peerInfo.peerId + of 1: + check: + event == PeerEvent.Left + peerId == switch1.peerInfo.peerId + else: + check false + + step.inc() + + switch2.addPeerEventHandler(handler, PeerEvent.Joined) + switch2.addPeerEventHandler(handler, PeerEvent.Left) + + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + + await switch2.connect(switch1.peerInfo) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + + await sleepAsync(100.millis) + await switch2.disconnect(switch1.peerInfo) + await sleepAsync(2.seconds) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + + var bufferTracker = getTracker(BufferStreamTrackerName) + # echo bufferTracker.dump() + check bufferTracker.isLeaked() == false + + var connTracker = getTracker(ConnectionTrackerName) + # echo connTracker.dump() + check connTracker.isLeaked() == false + + check: + kinds == { + PeerEvent.Joined, + PeerEvent.Left + } + + await allFuturesThrowing( + switch1.stop(), + switch2.stop()) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch()) + + test "e2e should trigger peer events only once per peer": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) + + let rng = newRng() + # use same private keys to emulate two connection from same peer + let privKey = PrivateKey.random(rng[]).tryGet() + let switch2 = newStandardSwitch( + privKey = some(privKey), + rng = rng, + secureManagers = [SecureProtocol.Secio]) + + let switch3 = newStandardSwitch( + privKey = some(privKey), + rng = rng, + secureManagers = [SecureProtocol.Secio]) + + var step = 0 + var kinds: set[PeerEvent] + proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = + kinds = kinds + {event} + case step: + of 0: + check: + event == PeerEvent.Joined + of 1: + check: + event == PeerEvent.Left + else: + check false # should not trigger this + + step.inc() + + switch1.addPeerEventHandler(handler, PeerEvent.Joined) + switch1.addPeerEventHandler(handler, PeerEvent.Left) + + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + awaiters.add(await switch3.start()) + + await switch2.connect(switch1.peerInfo) # should trigger 1st Join event + await switch3.connect(switch1.peerInfo) # should trigger 2nd Join event + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + check switch3.isConnected(switch1.peerInfo) + + await sleepAsync(100.millis) + await switch2.disconnect(switch1.peerInfo) # should trigger 1st Left event + await switch3.disconnect(switch1.peerInfo) # should trigger 2nd Left event + await sleepAsync(2.seconds) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + check not switch3.isConnected(switch1.peerInfo) + + var bufferTracker = getTracker(BufferStreamTrackerName) + # echo bufferTracker.dump() + check bufferTracker.isLeaked() == false + + var connTracker = getTracker(ConnectionTrackerName) + # echo connTracker.dump() + check connTracker.isLeaked() == false + + check: + kinds == { + PeerEvent.Joined, + PeerEvent.Left + } + + await allFuturesThrowing( + switch1.stop(), + switch2.stop(), + switch3.stop()) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch())