rework peer event to take an initiator flag (#456)

* rework peer event to take an initiator flag

* use correct direction for initiator
This commit is contained in:
Dmitriy Ryajov 2020-11-28 10:59:47 -06:00 committed by GitHub
parent 3d44fcb8b3
commit 18443dafc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 43 deletions

View File

@ -42,10 +42,17 @@ type
ConnEventHandler* = ConnEventHandler* =
proc(peerId: PeerID, event: ConnEvent): Future[void] {.gcsafe.} proc(peerId: PeerID, event: ConnEvent): Future[void] {.gcsafe.}
PeerEvent* {.pure.} = enum PeerEventKind* {.pure.} = enum
Left, Left,
Joined Joined
PeerEvent* = object
case kind*: PeerEventKind
of PeerEventKind.Joined:
initiator*: bool
else:
discard
PeerEventHandler* = PeerEventHandler* =
proc(peerId: PeerID, event: PeerEvent): Future[void] {.gcsafe.} proc(peerId: PeerID, event: PeerEvent): Future[void] {.gcsafe.}
@ -62,7 +69,7 @@ type
conns: Table[PeerID, HashSet[Connection]] conns: Table[PeerID, HashSet[Connection]]
muxed: Table[Connection, MuxerHolder] muxed: Table[Connection, MuxerHolder]
connEvents: Table[ConnEventKind, OrderedSet[ConnEventHandler]] connEvents: Table[ConnEventKind, OrderedSet[ConnEventHandler]]
peerEvents: Table[PeerEvent, OrderedSet[PeerEventHandler]] peerEvents: Table[PeerEventKind, OrderedSet[PeerEventHandler]]
proc newTooManyConnections(): ref TooManyConnections {.inline.} = proc newTooManyConnections(): ref TooManyConnections {.inline.} =
result = newException(TooManyConnections, "too many connections for peer") result = newException(TooManyConnections, "too many connections for peer")
@ -77,7 +84,8 @@ proc connCount*(c: ConnManager, peerId: PeerID): int =
c.conns.getOrDefault(peerId).len c.conns.getOrDefault(peerId).len
proc addConnEventHandler*(c: ConnManager, proc addConnEventHandler*(c: ConnManager,
handler: ConnEventHandler, kind: ConnEventKind) = handler: ConnEventHandler,
kind: ConnEventKind) =
## Add peer event handler - handlers must not raise exceptions! ## Add peer event handler - handlers must not raise exceptions!
## ##
@ -86,11 +94,14 @@ proc addConnEventHandler*(c: ConnManager,
initOrderedSet[ConnEventHandler]()).incl(handler) initOrderedSet[ConnEventHandler]()).incl(handler)
proc removeConnEventHandler*(c: ConnManager, proc removeConnEventHandler*(c: ConnManager,
handler: ConnEventHandler, kind: ConnEventKind) = handler: ConnEventHandler,
kind: ConnEventKind) =
c.connEvents.withValue(kind, handlers) do: c.connEvents.withValue(kind, handlers) do:
handlers[].excl(handler) handlers[].excl(handler)
proc triggerConnEvent*(c: ConnManager, peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = proc triggerConnEvent*(c: ConnManager,
peerId: PeerID,
event: ConnEvent) {.async, gcsafe.} =
try: try:
if event.kind in c.connEvents: if event.kind in c.connEvents:
var connEvents: seq[Future[void]] var connEvents: seq[Future[void]]
@ -106,7 +117,7 @@ proc triggerConnEvent*(c: ConnManager, peerId: PeerID, event: ConnEvent) {.async
proc addPeerEventHandler*(c: ConnManager, proc addPeerEventHandler*(c: ConnManager,
handler: PeerEventHandler, handler: PeerEventHandler,
kind: PeerEvent) = kind: PeerEventKind) =
## Add peer event handler - handlers must not raise exceptions! ## Add peer event handler - handlers must not raise exceptions!
## ##
@ -116,7 +127,7 @@ proc addPeerEventHandler*(c: ConnManager,
proc removePeerEventHandler*(c: ConnManager, proc removePeerEventHandler*(c: ConnManager,
handler: PeerEventHandler, handler: PeerEventHandler,
kind: PeerEvent) = kind: PeerEventKind) =
c.peerEvents.withValue(kind, handlers) do: c.peerEvents.withValue(kind, handlers) do:
handlers[].excl(handler) handlers[].excl(handler)
@ -125,22 +136,22 @@ proc triggerPeerEvents*(c: ConnManager,
event: PeerEvent) {.async, gcsafe.} = event: PeerEvent) {.async, gcsafe.} =
trace "About to trigger peer events", peer = peerId trace "About to trigger peer events", peer = peerId
if event notin c.peerEvents: if event.kind notin c.peerEvents:
return return
try: try:
let count = c.connCount(peerId) let count = c.connCount(peerId)
if event == PeerEvent.Joined and count != 1: if event.kind == PeerEventKind.Joined and count != 1:
trace "peer already joined", peerId, event = $event trace "peer already joined", peerId, event = $event
return return
elif event == PeerEvent.Left and count != 0: elif event.kind == PeerEventKind.Left and count != 0:
trace "peer still connected or already left", peerId, event = $event trace "peer still connected or already left", peerId, event = $event
return return
trace "triggering peer events", peerId, event = $event trace "triggering peer events", peerId, event = $event
var peerEvents: seq[Future[void]] var peerEvents: seq[Future[void]]
for h in c.peerEvents[event]: for h in c.peerEvents[event.kind]:
peerEvents.add(h(peerId, event)) peerEvents.add(h(peerId, event))
checkFutures(await allFinished(peerEvents)) checkFutures(await allFinished(peerEvents))
@ -229,7 +240,8 @@ proc peerStartup(c: ConnManager, conn: Connection) {.async.} =
try: try:
trace "Triggering connect events", conn trace "Triggering connect events", conn
let peerId = conn.peerInfo.peerId let peerId = conn.peerInfo.peerId
await c.triggerPeerEvents(peerId, PeerEvent.Joined) await c.triggerPeerEvents(
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out))
await c.triggerConnEvent( await c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In)) peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In))
except CatchableError as exc: except CatchableError as exc:
@ -244,7 +256,7 @@ proc peerCleanup(c: ConnManager, conn: Connection) {.async.} =
let peerId = conn.peerInfo.peerId let peerId = conn.peerInfo.peerId
await c.triggerConnEvent( await c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Disconnected)) peerId, ConnEvent(kind: ConnEventKind.Disconnected))
await c.triggerPeerEvents(peerId, PeerEvent.Left) await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left))
except CatchableError as exc: except CatchableError as exc:
# This is top-level procedure which will work as separate task, so it # This is top-level procedure which will work as separate task, so it
# do not need to propagate CancelledError and should handle other errors # do not need to propagate CancelledError and should handle other errors

View File

@ -383,19 +383,18 @@ proc init*[PubParams: object | bool](
parameters: parameters) parameters: parameters)
proc peerEventHandler(peerId: PeerID, event: PeerEvent) {.async.} = proc peerEventHandler(peerId: PeerID, event: PeerEvent) {.async.} =
if event == PeerEvent.Joined: if event.kind == PeerEventKind.Joined:
pubsub.subscribePeer(peerId) pubsub.subscribePeer(peerId)
else: else:
pubsub.unsubscribePeer(peerId) pubsub.unsubscribePeer(peerId)
switch.addPeerEventHandler(peerEventHandler, PeerEvent.Joined) switch.addPeerEventHandler(peerEventHandler, PeerEventKind.Joined)
switch.addPeerEventHandler(peerEventHandler, PeerEvent.Left) switch.addPeerEventHandler(peerEventHandler, PeerEventKind.Left)
pubsub.initPubSub() pubsub.initPubSub()
return pubsub return pubsub
proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer
proc removeObserver*(p: PubSub; observer: PubSubObserver) = proc removeObserver*(p: PubSub; observer: PubSubObserver) =

View File

@ -74,12 +74,12 @@ proc removeConnEventHandler*(s: Switch,
proc addPeerEventHandler*(s: Switch, proc addPeerEventHandler*(s: Switch,
handler: PeerEventHandler, handler: PeerEventHandler,
kind: PeerEvent) = kind: PeerEventKind) =
s.connManager.addPeerEventHandler(handler, kind) s.connManager.addPeerEventHandler(handler, kind)
proc removePeerEventHandler*(s: Switch, proc removePeerEventHandler*(s: Switch,
handler: PeerEventHandler, handler: PeerEventHandler,
kind: PeerEvent) = kind: PeerEventKind) =
s.connManager.removePeerEventHandler(handler, kind) s.connManager.removePeerEventHandler(handler, kind)
proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.} proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.}

View File

@ -359,25 +359,25 @@ suite "Switch":
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio])
var step = 0 var step = 0
var kinds: set[PeerEvent] var kinds: set[PeerEventKind]
proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} =
kinds = kinds + {event} kinds = kinds + {event.kind}
case step: case step:
of 0: of 0:
check: check:
event == PeerEvent.Joined event.kind == PeerEventKind.Joined
peerId == switch2.peerInfo.peerId peerId == switch2.peerInfo.peerId
of 1: of 1:
check: check:
event == PeerEvent.Left event.kind == PeerEventKind.Left
peerId == switch2.peerInfo.peerId peerId == switch2.peerInfo.peerId
else: else:
check false check false
step.inc() step.inc()
switch1.addPeerEventHandler(handler, PeerEvent.Joined) switch1.addPeerEventHandler(handler, PeerEventKind.Joined)
switch1.addPeerEventHandler(handler, PeerEvent.Left) switch1.addPeerEventHandler(handler, PeerEventKind.Left)
awaiters.add(await switch1.start()) awaiters.add(await switch1.start())
awaiters.add(await switch2.start()) awaiters.add(await switch2.start())
@ -398,8 +398,8 @@ suite "Switch":
check: check:
kinds == { kinds == {
PeerEvent.Joined, PeerEventKind.Joined,
PeerEvent.Left PeerEventKind.Left
} }
await allFuturesThrowing( await allFuturesThrowing(
@ -414,25 +414,25 @@ suite "Switch":
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio]) let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio])
var step = 0 var step = 0
var kinds: set[PeerEvent] var kinds: set[PeerEventKind]
proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} =
kinds = kinds + {event} kinds = kinds + {event.kind}
case step: case step:
of 0: of 0:
check: check:
event == PeerEvent.Joined event.kind == PeerEventKind.Joined
peerId == switch1.peerInfo.peerId peerId == switch1.peerInfo.peerId
of 1: of 1:
check: check:
event == PeerEvent.Left event.kind == PeerEventKind.Left
peerId == switch1.peerInfo.peerId peerId == switch1.peerInfo.peerId
else: else:
check false check false
step.inc() step.inc()
switch2.addPeerEventHandler(handler, PeerEvent.Joined) switch2.addPeerEventHandler(handler, PeerEventKind.Joined)
switch2.addPeerEventHandler(handler, PeerEvent.Left) switch2.addPeerEventHandler(handler, PeerEventKind.Left)
awaiters.add(await switch1.start()) awaiters.add(await switch1.start())
awaiters.add(await switch2.start()) awaiters.add(await switch2.start())
@ -453,8 +453,8 @@ suite "Switch":
check: check:
kinds == { kinds == {
PeerEvent.Joined, PeerEventKind.Joined,
PeerEvent.Left PeerEventKind.Left
} }
await allFuturesThrowing( await allFuturesThrowing(
@ -481,23 +481,23 @@ suite "Switch":
secureManagers = [SecureProtocol.Secio]) secureManagers = [SecureProtocol.Secio])
var step = 0 var step = 0
var kinds: set[PeerEvent] var kinds: set[PeerEventKind]
proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} =
kinds = kinds + {event} kinds = kinds + {event.kind}
case step: case step:
of 0: of 0:
check: check:
event == PeerEvent.Joined event.kind == PeerEventKind.Joined
of 1: of 1:
check: check:
event == PeerEvent.Left event.kind == PeerEventKind.Left
else: else:
check false # should not trigger this check false # should not trigger this
step.inc() step.inc()
switch1.addPeerEventHandler(handler, PeerEvent.Joined) switch1.addPeerEventHandler(handler, PeerEventKind.Joined)
switch1.addPeerEventHandler(handler, PeerEvent.Left) switch1.addPeerEventHandler(handler, PeerEventKind.Left)
awaiters.add(await switch1.start()) awaiters.add(await switch1.start())
awaiters.add(await switch2.start()) awaiters.add(await switch2.start())
@ -523,8 +523,8 @@ suite "Switch":
check: check:
kinds == { kinds == {
PeerEvent.Joined, PeerEventKind.Joined,
PeerEvent.Left PeerEventKind.Left
} }
await allFuturesThrowing( await allFuturesThrowing(