Connection lifecycle hooks (#288)
* lifecycle hooks * trigger hooks as tasks * handle exceptions in trigger hooks * trigger hooks after storing the connection * add disconnected hook * tests
This commit is contained in:
parent
6af3cb6406
commit
ed0df74bbd
|
@ -28,7 +28,8 @@ import stream/connection,
|
||||||
protocols/pubsub/pubsub,
|
protocols/pubsub/pubsub,
|
||||||
muxers/muxer,
|
muxers/muxer,
|
||||||
connmanager,
|
connmanager,
|
||||||
peerid
|
peerid,
|
||||||
|
errors
|
||||||
|
|
||||||
logScope:
|
logScope:
|
||||||
topics = "switch"
|
topics = "switch"
|
||||||
|
@ -46,6 +47,13 @@ declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
|
||||||
type
|
type
|
||||||
NoPubSubException* = object of CatchableError
|
NoPubSubException* = object of CatchableError
|
||||||
|
|
||||||
|
Lifecycle* {.pure.} = enum
|
||||||
|
Connected,
|
||||||
|
Upgraded,
|
||||||
|
Disconnected
|
||||||
|
|
||||||
|
Hook* = proc(peer: PeerInfo, cycle: Lifecycle): Future[void] {.gcsafe.}
|
||||||
|
|
||||||
Switch* = ref object of RootObj
|
Switch* = ref object of RootObj
|
||||||
peerInfo*: PeerInfo
|
peerInfo*: PeerInfo
|
||||||
connManager: ConnManager
|
connManager: ConnManager
|
||||||
|
@ -58,10 +66,31 @@ type
|
||||||
secureManagers*: seq[Secure]
|
secureManagers*: seq[Secure]
|
||||||
pubSub*: Option[PubSub]
|
pubSub*: Option[PubSub]
|
||||||
dialLock: Table[string, AsyncLock]
|
dialLock: Table[string, AsyncLock]
|
||||||
|
hooks: Table[Lifecycle, HashSet[Hook]]
|
||||||
|
|
||||||
proc newNoPubSubException(): ref NoPubSubException {.inline.} =
|
proc newNoPubSubException(): ref NoPubSubException {.inline.} =
|
||||||
result = newException(NoPubSubException, "no pubsub provided!")
|
result = newException(NoPubSubException, "no pubsub provided!")
|
||||||
|
|
||||||
|
proc addHook*(s: Switch, hook: Hook, cycle: Lifecycle) =
|
||||||
|
s.hooks.mgetOrPut(cycle, initHashSet[Hook]()).incl(hook)
|
||||||
|
|
||||||
|
proc removeHook*(s: Switch, hook: Hook, cycle: Lifecycle) =
|
||||||
|
s.hooks.mgetOrPut(cycle, initHashSet[Hook]()).excl(hook)
|
||||||
|
|
||||||
|
proc triggerHooks(s: Switch, peer: PeerInfo, cycle: Lifecycle) {.async, gcsafe.} =
|
||||||
|
try:
|
||||||
|
if cycle in s.hooks:
|
||||||
|
var hooks: seq[Future[void]]
|
||||||
|
for h in s.hooks[cycle]:
|
||||||
|
if not(isNil(h)):
|
||||||
|
hooks.add(h(peer, cycle))
|
||||||
|
|
||||||
|
checkFutures(await allFinished(hooks))
|
||||||
|
except CancelledError as exc:
|
||||||
|
raise exc
|
||||||
|
except CatchableError as exc:
|
||||||
|
trace "exception in trigger hooks", exc = exc.msg
|
||||||
|
|
||||||
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.}
|
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.}
|
||||||
proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}
|
proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}
|
||||||
|
|
||||||
|
@ -187,9 +216,10 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
|
||||||
if isNil(sconn.peerInfo):
|
if isNil(sconn.peerInfo):
|
||||||
await sconn.close()
|
await sconn.close()
|
||||||
raise newException(CatchableError,
|
raise newException(CatchableError,
|
||||||
"unable to mux connection, stopping upgrade")
|
"unable to identify connection, stopping upgrade")
|
||||||
|
|
||||||
|
trace "succesfully upgraded outgoing connection", oid = sconn.oid
|
||||||
|
|
||||||
trace "succesfully upgraded outgoing connection", uoid = sconn.oid
|
|
||||||
return sconn
|
return sconn
|
||||||
|
|
||||||
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||||
|
@ -258,6 +288,13 @@ proc internalConnect(s: Switch,
|
||||||
# make sure to assign the peer to the connection
|
# make sure to assign the peer to the connection
|
||||||
conn.peerInfo = peer
|
conn.peerInfo = peer
|
||||||
|
|
||||||
|
conn.closeEvent.wait()
|
||||||
|
.addCallback do(udata: pointer):
|
||||||
|
asyncCheck s.triggerHooks(
|
||||||
|
conn.peerInfo,
|
||||||
|
Lifecycle.Disconnected)
|
||||||
|
|
||||||
|
asyncCheck s.triggerHooks(conn.peerInfo, Lifecycle.Connected)
|
||||||
libp2p_dialed_peers.inc()
|
libp2p_dialed_peers.inc()
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
trace "dialing canceled", exc = exc.msg
|
trace "dialing canceled", exc = exc.msg
|
||||||
|
@ -270,7 +307,9 @@ proc internalConnect(s: Switch,
|
||||||
try:
|
try:
|
||||||
let uconn = await s.upgradeOutgoing(conn)
|
let uconn = await s.upgradeOutgoing(conn)
|
||||||
s.connManager.storeOutgoing(uconn)
|
s.connManager.storeOutgoing(uconn)
|
||||||
|
asyncCheck s.triggerHooks(uconn.peerInfo, Lifecycle.Upgraded)
|
||||||
conn = uconn
|
conn = uconn
|
||||||
|
trace "dial succesfull", oid = $conn.oid, peer = $conn.peerInfo
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
if not(isNil(conn)):
|
if not(isNil(conn)):
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
@ -283,7 +322,7 @@ proc internalConnect(s: Switch,
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
trace "Reusing existing connection", oid = conn.oid
|
trace "Reusing existing connection", oid = $conn.oid, direction = conn.dir
|
||||||
finally:
|
finally:
|
||||||
if lock.locked():
|
if lock.locked():
|
||||||
lock.release()
|
lock.release()
|
||||||
|
@ -360,7 +399,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
|
||||||
|
|
||||||
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
|
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
|
||||||
try:
|
try:
|
||||||
conn.dir = Direction.In # tag connection with direction
|
|
||||||
|
conn.closeEvent.wait()
|
||||||
|
.addCallback do(udata: pointer):
|
||||||
|
asyncCheck s.triggerHooks(
|
||||||
|
conn.peerInfo,
|
||||||
|
Lifecycle.Disconnected)
|
||||||
|
|
||||||
|
asyncCheck s.triggerHooks(conn.peerInfo, Lifecycle.Connected)
|
||||||
await s.upgradeIncoming(conn) # perform upgrade on incoming connection
|
await s.upgradeIncoming(conn) # perform upgrade on incoming connection
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
raise exc
|
raise exc
|
||||||
|
@ -437,6 +483,8 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
|
||||||
proc subscribe*(s: Switch, topic: string,
|
proc subscribe*(s: Switch, topic: string,
|
||||||
handler: TopicHandler) {.async.} =
|
handler: TopicHandler) {.async.} =
|
||||||
## subscribe to a pubsub topic
|
## subscribe to a pubsub topic
|
||||||
|
##
|
||||||
|
|
||||||
if s.pubSub.isNone:
|
if s.pubSub.isNone:
|
||||||
raise newNoPubSubException()
|
raise newNoPubSubException()
|
||||||
|
|
||||||
|
@ -444,6 +492,8 @@ proc subscribe*(s: Switch, topic: string,
|
||||||
|
|
||||||
proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} =
|
proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} =
|
||||||
## unsubscribe from topics
|
## unsubscribe from topics
|
||||||
|
##
|
||||||
|
|
||||||
if s.pubSub.isNone:
|
if s.pubSub.isNone:
|
||||||
raise newNoPubSubException()
|
raise newNoPubSubException()
|
||||||
|
|
||||||
|
@ -457,7 +507,9 @@ proc unsubscribeAll*(s: Switch, topic: string) {.async.} =
|
||||||
await s.pubSub.get().unsubscribeAll(topic)
|
await s.pubSub.get().unsubscribeAll(topic)
|
||||||
|
|
||||||
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} =
|
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} =
|
||||||
# pubslish to pubsub topic
|
## pubslish to pubsub topic
|
||||||
|
##
|
||||||
|
|
||||||
if s.pubSub.isNone:
|
if s.pubSub.isNone:
|
||||||
raise newNoPubSubException()
|
raise newNoPubSubException()
|
||||||
|
|
||||||
|
@ -466,7 +518,9 @@ proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.}
|
||||||
proc addValidator*(s: Switch,
|
proc addValidator*(s: Switch,
|
||||||
topics: varargs[string],
|
topics: varargs[string],
|
||||||
hook: ValidatorHandler) =
|
hook: ValidatorHandler) =
|
||||||
# add validator
|
## add validator
|
||||||
|
##
|
||||||
|
|
||||||
if s.pubSub.isNone:
|
if s.pubSub.isNone:
|
||||||
raise newNoPubSubException()
|
raise newNoPubSubException()
|
||||||
|
|
||||||
|
@ -475,7 +529,9 @@ proc addValidator*(s: Switch,
|
||||||
proc removeValidator*(s: Switch,
|
proc removeValidator*(s: Switch,
|
||||||
topics: varargs[string],
|
topics: varargs[string],
|
||||||
hook: ValidatorHandler) =
|
hook: ValidatorHandler) =
|
||||||
# pubslish to pubsub topic
|
## pubslish to pubsub topic
|
||||||
|
##
|
||||||
|
|
||||||
if s.pubSub.isNone:
|
if s.pubSub.isNone:
|
||||||
raise newNoPubSubException()
|
raise newNoPubSubException()
|
||||||
|
|
||||||
|
@ -504,6 +560,7 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
|
||||||
s.connManager.storeMuxer(muxer)
|
s.connManager.storeMuxer(muxer)
|
||||||
|
|
||||||
trace "got new muxer", peer = $muxer.connection.peerInfo
|
trace "got new muxer", peer = $muxer.connection.peerInfo
|
||||||
|
asyncCheck s.triggerHooks(muxer.connection.peerInfo, Lifecycle.Upgraded)
|
||||||
|
|
||||||
# try establishing a pubsub connection
|
# try establishing a pubsub connection
|
||||||
await s.subscribePeer(muxer.connection.peerInfo)
|
await s.subscribePeer(muxer.connection.peerInfo)
|
||||||
|
|
|
@ -228,3 +228,81 @@ suite "Switch":
|
||||||
await allFuturesThrowing(awaiters)
|
await allFuturesThrowing(awaiters)
|
||||||
|
|
||||||
waitFor(testSwitch())
|
waitFor(testSwitch())
|
||||||
|
|
||||||
|
test "e2e should trigger hooks":
|
||||||
|
proc testSwitch() {.async, gcsafe.} =
|
||||||
|
var awaiters: seq[Future[void]]
|
||||||
|
|
||||||
|
let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Secio])
|
||||||
|
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Secio])
|
||||||
|
|
||||||
|
var step = 0
|
||||||
|
var cycles: set[Lifecycle]
|
||||||
|
proc hook(peer: PeerInfo, cycle: Lifecycle) {.async, gcsafe.} =
|
||||||
|
cycles = cycles + {cycle}
|
||||||
|
case step:
|
||||||
|
of 0:
|
||||||
|
check cycle == Lifecycle.Connected
|
||||||
|
check if not(isNil(peer)):
|
||||||
|
peer.peerId == switch2.peerInfo.peerId
|
||||||
|
else:
|
||||||
|
true
|
||||||
|
of 1:
|
||||||
|
assert(isNil(peer) == false)
|
||||||
|
check:
|
||||||
|
cycle == Lifecycle.Upgraded
|
||||||
|
peer.peerId == switch2.peerInfo.peerId
|
||||||
|
of 2:
|
||||||
|
check:
|
||||||
|
cycle == Lifecycle.Disconnected
|
||||||
|
|
||||||
|
check if not(isNil(peer)):
|
||||||
|
peer.peerId == switch2.peerInfo.peerId
|
||||||
|
else:
|
||||||
|
true
|
||||||
|
else:
|
||||||
|
echo "unkown cycle! ", $cycle
|
||||||
|
check false
|
||||||
|
|
||||||
|
step.inc()
|
||||||
|
|
||||||
|
switch1.addHook(hook, Lifecycle.Connected)
|
||||||
|
switch1.addHook(hook, Lifecycle.Upgraded)
|
||||||
|
switch1.addHook(hook, Lifecycle.Disconnected)
|
||||||
|
|
||||||
|
awaiters.add(await switch1.start())
|
||||||
|
awaiters.add(await switch2.start())
|
||||||
|
|
||||||
|
await switch2.connect(switch1.peerInfo)
|
||||||
|
|
||||||
|
check switch1.isConnected(switch2.peerInfo)
|
||||||
|
check switch2.isConnected(switch1.peerInfo)
|
||||||
|
|
||||||
|
await sleepAsync(100.millis)
|
||||||
|
await switch2.disconnect(switch1.peerInfo)
|
||||||
|
await sleepAsync(2.seconds)
|
||||||
|
|
||||||
|
check not switch1.isConnected(switch2.peerInfo)
|
||||||
|
check not switch2.isConnected(switch1.peerInfo)
|
||||||
|
|
||||||
|
var bufferTracker = getTracker(BufferStreamTrackerName)
|
||||||
|
# echo bufferTracker.dump()
|
||||||
|
check bufferTracker.isLeaked() == false
|
||||||
|
|
||||||
|
var connTracker = getTracker(ConnectionTrackerName)
|
||||||
|
# echo connTracker.dump()
|
||||||
|
check connTracker.isLeaked() == false
|
||||||
|
|
||||||
|
check:
|
||||||
|
cycles == {
|
||||||
|
Lifecycle.Connected,
|
||||||
|
Lifecycle.Upgraded,
|
||||||
|
Lifecycle.Disconnected
|
||||||
|
}
|
||||||
|
|
||||||
|
await allFuturesThrowing(
|
||||||
|
switch1.stop(),
|
||||||
|
switch2.stop())
|
||||||
|
await allFuturesThrowing(awaiters)
|
||||||
|
|
||||||
|
waitFor(testSwitch())
|
||||||
|
|
Loading…
Reference in New Issue