diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 2d90c2d..34b7593 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -42,10 +42,17 @@ type ConnEventHandler* = proc(peerId: PeerID, event: ConnEvent): Future[void] {.gcsafe.} - PeerEvent* {.pure.} = enum + PeerEventKind* {.pure.} = enum Left, Joined + PeerEvent* = object + case kind*: PeerEventKind + of PeerEventKind.Joined: + initiator*: bool + else: + discard + PeerEventHandler* = proc(peerId: PeerID, event: PeerEvent): Future[void] {.gcsafe.} @@ -62,7 +69,7 @@ type conns: Table[PeerID, HashSet[Connection]] muxed: Table[Connection, MuxerHolder] connEvents: Table[ConnEventKind, OrderedSet[ConnEventHandler]] - peerEvents: Table[PeerEvent, OrderedSet[PeerEventHandler]] + peerEvents: Table[PeerEventKind, OrderedSet[PeerEventHandler]] proc newTooManyConnections(): ref TooManyConnections {.inline.} = result = newException(TooManyConnections, "too many connections for peer") @@ -77,7 +84,8 @@ proc connCount*(c: ConnManager, peerId: PeerID): int = c.conns.getOrDefault(peerId).len proc addConnEventHandler*(c: ConnManager, - handler: ConnEventHandler, kind: ConnEventKind) = + handler: ConnEventHandler, + kind: ConnEventKind) = ## Add peer event handler - handlers must not raise exceptions! ## @@ -86,11 +94,14 @@ proc addConnEventHandler*(c: ConnManager, initOrderedSet[ConnEventHandler]()).incl(handler) proc removeConnEventHandler*(c: ConnManager, - handler: ConnEventHandler, kind: ConnEventKind) = + handler: ConnEventHandler, + kind: ConnEventKind) = c.connEvents.withValue(kind, handlers) do: handlers[].excl(handler) -proc triggerConnEvent*(c: ConnManager, peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = +proc triggerConnEvent*(c: ConnManager, + peerId: PeerID, + event: ConnEvent) {.async, gcsafe.} = try: if event.kind in c.connEvents: var connEvents: seq[Future[void]] @@ -106,7 +117,7 @@ proc triggerConnEvent*(c: ConnManager, peerId: PeerID, event: ConnEvent) {.async proc addPeerEventHandler*(c: ConnManager, handler: PeerEventHandler, - kind: PeerEvent) = + kind: PeerEventKind) = ## Add peer event handler - handlers must not raise exceptions! ## @@ -116,7 +127,7 @@ proc addPeerEventHandler*(c: ConnManager, proc removePeerEventHandler*(c: ConnManager, handler: PeerEventHandler, - kind: PeerEvent) = + kind: PeerEventKind) = c.peerEvents.withValue(kind, handlers) do: handlers[].excl(handler) @@ -125,22 +136,22 @@ proc triggerPeerEvents*(c: ConnManager, event: PeerEvent) {.async, gcsafe.} = trace "About to trigger peer events", peer = peerId - if event notin c.peerEvents: + if event.kind notin c.peerEvents: return try: let count = c.connCount(peerId) - if event == PeerEvent.Joined and count != 1: + if event.kind == PeerEventKind.Joined and count != 1: trace "peer already joined", peerId, event = $event return - elif event == PeerEvent.Left and count != 0: + elif event.kind == PeerEventKind.Left and count != 0: trace "peer still connected or already left", peerId, event = $event return trace "triggering peer events", peerId, event = $event var peerEvents: seq[Future[void]] - for h in c.peerEvents[event]: + for h in c.peerEvents[event.kind]: peerEvents.add(h(peerId, event)) checkFutures(await allFinished(peerEvents)) @@ -229,7 +240,8 @@ proc peerStartup(c: ConnManager, conn: Connection) {.async.} = try: trace "Triggering connect events", conn let peerId = conn.peerInfo.peerId - await c.triggerPeerEvents(peerId, PeerEvent.Joined) + await c.triggerPeerEvents( + peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out)) await c.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In)) except CatchableError as exc: @@ -244,7 +256,7 @@ proc peerCleanup(c: ConnManager, conn: Connection) {.async.} = let peerId = conn.peerInfo.peerId await c.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Disconnected)) - await c.triggerPeerEvents(peerId, PeerEvent.Left) + await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left)) except CatchableError as exc: # This is top-level procedure which will work as separate task, so it # do not need to propagate CancelledError and should handle other errors diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index ae9bc34..a747847 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -383,19 +383,18 @@ proc init*[PubParams: object | bool]( parameters: parameters) proc peerEventHandler(peerId: PeerID, event: PeerEvent) {.async.} = - if event == PeerEvent.Joined: + if event.kind == PeerEventKind.Joined: pubsub.subscribePeer(peerId) else: pubsub.unsubscribePeer(peerId) - switch.addPeerEventHandler(peerEventHandler, PeerEvent.Joined) - switch.addPeerEventHandler(peerEventHandler, PeerEvent.Left) + switch.addPeerEventHandler(peerEventHandler, PeerEventKind.Joined) + switch.addPeerEventHandler(peerEventHandler, PeerEventKind.Left) pubsub.initPubSub() return pubsub - proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer proc removeObserver*(p: PubSub; observer: PubSubObserver) = diff --git a/libp2p/switch.nim b/libp2p/switch.nim index f29e0d6..6e81199 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -74,12 +74,12 @@ proc removeConnEventHandler*(s: Switch, proc addPeerEventHandler*(s: Switch, handler: PeerEventHandler, - kind: PeerEvent) = + kind: PeerEventKind) = s.connManager.addPeerEventHandler(handler, kind) proc removePeerEventHandler*(s: Switch, handler: PeerEventHandler, - kind: PeerEvent) = + kind: PeerEventKind) = s.connManager.removePeerEventHandler(handler, kind) proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.} diff --git a/tests/testswitch.nim b/tests/testswitch.nim index a6fa9bc..6e3ab23 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -359,25 +359,25 @@ suite "Switch": let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) var step = 0 - var kinds: set[PeerEvent] + var kinds: set[PeerEventKind] proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = - kinds = kinds + {event} + kinds = kinds + {event.kind} case step: of 0: check: - event == PeerEvent.Joined + event.kind == PeerEventKind.Joined peerId == switch2.peerInfo.peerId of 1: check: - event == PeerEvent.Left + event.kind == PeerEventKind.Left peerId == switch2.peerInfo.peerId else: check false step.inc() - switch1.addPeerEventHandler(handler, PeerEvent.Joined) - switch1.addPeerEventHandler(handler, PeerEvent.Left) + switch1.addPeerEventHandler(handler, PeerEventKind.Joined) + switch1.addPeerEventHandler(handler, PeerEventKind.Left) awaiters.add(await switch1.start()) awaiters.add(await switch2.start()) @@ -398,8 +398,8 @@ suite "Switch": check: kinds == { - PeerEvent.Joined, - PeerEvent.Left + PeerEventKind.Joined, + PeerEventKind.Left } await allFuturesThrowing( @@ -414,25 +414,25 @@ suite "Switch": let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) var step = 0 - var kinds: set[PeerEvent] + var kinds: set[PeerEventKind] proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = - kinds = kinds + {event} + kinds = kinds + {event.kind} case step: of 0: check: - event == PeerEvent.Joined + event.kind == PeerEventKind.Joined peerId == switch1.peerInfo.peerId of 1: check: - event == PeerEvent.Left + event.kind == PeerEventKind.Left peerId == switch1.peerInfo.peerId else: check false step.inc() - switch2.addPeerEventHandler(handler, PeerEvent.Joined) - switch2.addPeerEventHandler(handler, PeerEvent.Left) + switch2.addPeerEventHandler(handler, PeerEventKind.Joined) + switch2.addPeerEventHandler(handler, PeerEventKind.Left) awaiters.add(await switch1.start()) awaiters.add(await switch2.start()) @@ -453,8 +453,8 @@ suite "Switch": check: kinds == { - PeerEvent.Joined, - PeerEvent.Left + PeerEventKind.Joined, + PeerEventKind.Left } await allFuturesThrowing( @@ -481,23 +481,23 @@ suite "Switch": secureManagers = [SecureProtocol.Secio]) var step = 0 - var kinds: set[PeerEvent] + var kinds: set[PeerEventKind] proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = - kinds = kinds + {event} + kinds = kinds + {event.kind} case step: of 0: check: - event == PeerEvent.Joined + event.kind == PeerEventKind.Joined of 1: check: - event == PeerEvent.Left + event.kind == PeerEventKind.Left else: check false # should not trigger this step.inc() - switch1.addPeerEventHandler(handler, PeerEvent.Joined) - switch1.addPeerEventHandler(handler, PeerEvent.Left) + switch1.addPeerEventHandler(handler, PeerEventKind.Joined) + switch1.addPeerEventHandler(handler, PeerEventKind.Left) awaiters.add(await switch1.start()) awaiters.add(await switch2.start()) @@ -523,8 +523,8 @@ suite "Switch": check: kinds == { - PeerEvent.Joined, - PeerEvent.Left + PeerEventKind.Joined, + PeerEventKind.Left } await allFuturesThrowing(