Add PubSub observer+ hooks (they can modify as well)
This commit is contained in:
parent
268253ea18
commit
c889224012
|
@ -138,6 +138,7 @@ method unsubscribe*(f: FloodSub,
|
||||||
await f.sendSubs(p, topics.mapIt(it.topic).deduplicate(), false)
|
await f.sendSubs(p, topics.mapIt(it.topic).deduplicate(), false)
|
||||||
|
|
||||||
method initPubSub*(f: FloodSub) =
|
method initPubSub*(f: FloodSub) =
|
||||||
|
procCall PubSub(f).initPubSub()
|
||||||
f.peers = initTable[string, PubSubPeer]()
|
f.peers = initTable[string, PubSubPeer]()
|
||||||
f.topics = initTable[string, Topic]()
|
f.topics = initTable[string, Topic]()
|
||||||
f.floodsub = initTable[string, HashSet[string]]()
|
f.floodsub = initTable[string, HashSet[string]]()
|
||||||
|
|
|
@ -16,6 +16,7 @@ import pubsubpeer,
|
||||||
../../peerinfo
|
../../peerinfo
|
||||||
|
|
||||||
export PubSubPeer
|
export PubSubPeer
|
||||||
|
export PubSubObserver
|
||||||
|
|
||||||
logScope:
|
logScope:
|
||||||
topic = "PubSub"
|
topic = "PubSub"
|
||||||
|
@ -42,6 +43,7 @@ type
|
||||||
sign*: bool # enable message signing
|
sign*: bool # enable message signing
|
||||||
cleanupLock: AsyncLock
|
cleanupLock: AsyncLock
|
||||||
validators*: Table[string, HashSet[ValidatorHandler]]
|
validators*: Table[string, HashSet[ValidatorHandler]]
|
||||||
|
observers: ref seq[PubSubObserver] # ref as in smart_ptr
|
||||||
|
|
||||||
proc sendSubs*(p: PubSub,
|
proc sendSubs*(p: PubSub,
|
||||||
peer: PubSubPeer,
|
peer: PubSubPeer,
|
||||||
|
@ -72,6 +74,7 @@ method rpcHandler*(p: PubSub,
|
||||||
rpcMsgs: seq[RPCMsg]) {.async, base.} =
|
rpcMsgs: seq[RPCMsg]) {.async, base.} =
|
||||||
## handle rpc messages
|
## handle rpc messages
|
||||||
trace "processing RPC message", peer = peer.id, msgs = rpcMsgs.len
|
trace "processing RPC message", peer = peer.id, msgs = rpcMsgs.len
|
||||||
|
|
||||||
for m in rpcMsgs: # for all RPC messages
|
for m in rpcMsgs: # for all RPC messages
|
||||||
trace "processing messages", msg = m.shortLog
|
trace "processing messages", msg = m.shortLog
|
||||||
if m.subscriptions.len > 0: # if there are any subscriptions
|
if m.subscriptions.len > 0: # if there are any subscriptions
|
||||||
|
@ -104,6 +107,7 @@ proc getPeer(p: PubSub,
|
||||||
|
|
||||||
p.peers[peer.id] = peer
|
p.peers[peer.id] = peer
|
||||||
peer.refs.inc # increment reference cound
|
peer.refs.inc # increment reference cound
|
||||||
|
peer.observers = p.observers
|
||||||
result = peer
|
result = peer
|
||||||
|
|
||||||
method handleConn*(p: PubSub,
|
method handleConn*(p: PubSub,
|
||||||
|
@ -201,7 +205,7 @@ method publish*(p: PubSub,
|
||||||
|
|
||||||
method initPubSub*(p: PubSub) {.base.} =
|
method initPubSub*(p: PubSub) {.base.} =
|
||||||
## perform pubsub initializaion
|
## perform pubsub initializaion
|
||||||
discard
|
p.observers = new(seq[PubSubObserver])
|
||||||
|
|
||||||
method start*(p: PubSub) {.async, base.} =
|
method start*(p: PubSub) {.async, base.} =
|
||||||
## start pubsub
|
## start pubsub
|
||||||
|
@ -253,3 +257,10 @@ proc newPubSub*(P: typedesc[PubSub],
|
||||||
sign: sign,
|
sign: sign,
|
||||||
cleanupLock: newAsyncLock())
|
cleanupLock: newAsyncLock())
|
||||||
result.initPubSub()
|
result.initPubSub()
|
||||||
|
|
||||||
|
proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer
|
||||||
|
|
||||||
|
proc removeObserver*(p: PubSub; observer: PubSubObserver) =
|
||||||
|
let idx = p.observers[].find(observer)
|
||||||
|
if idx != -1:
|
||||||
|
p.observers[].del(idx)
|
||||||
|
|
|
@ -23,6 +23,10 @@ logScope:
|
||||||
topic = "PubSubPeer"
|
topic = "PubSubPeer"
|
||||||
|
|
||||||
type
|
type
|
||||||
|
PubSubObserver* = ref object
|
||||||
|
onRecv*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe.}
|
||||||
|
onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe.}
|
||||||
|
|
||||||
PubSubPeer* = ref object of RootObj
|
PubSubPeer* = ref object of RootObj
|
||||||
proto: string # the protocol that this peer joined from
|
proto: string # the protocol that this peer joined from
|
||||||
sendConn: Connection
|
sendConn: Connection
|
||||||
|
@ -33,6 +37,7 @@ type
|
||||||
recvdRpcCache: TimedCache[string] # cache for already received messages
|
recvdRpcCache: TimedCache[string] # cache for already received messages
|
||||||
refs*: int # refcount of the connections this peer is handling
|
refs*: int # refcount of the connections this peer is handling
|
||||||
onConnect: AsyncEvent
|
onConnect: AsyncEvent
|
||||||
|
observers*: ref seq[PubSubObserver] # ref as in smart_ptr
|
||||||
|
|
||||||
RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.}
|
RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.}
|
||||||
|
|
||||||
|
@ -58,8 +63,11 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
|
||||||
trace "message already received, skipping", peer = p.id
|
trace "message already received, skipping", peer = p.id
|
||||||
continue
|
continue
|
||||||
|
|
||||||
let msg = decodeRpcMsg(data)
|
var msg = decodeRpcMsg(data)
|
||||||
trace "decoded msg from peer", peer = p.id, msg = msg.shortLog
|
trace "decoded msg from peer", peer = p.id, msg = msg.shortLog
|
||||||
|
# trigger hooks
|
||||||
|
for obs in p.observers[]:
|
||||||
|
obs.onRecv(p, msg)
|
||||||
await p.handler(p, @[msg])
|
await p.handler(p, @[msg])
|
||||||
p.recvdRpcCache.put($hexData.hash)
|
p.recvdRpcCache.put($hexData.hash)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
|
@ -71,9 +79,14 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
|
||||||
|
|
||||||
proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
|
proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
|
||||||
try:
|
try:
|
||||||
for m in msgs:
|
for m in msgs.items:
|
||||||
trace "sending msgs to peer", toPeer = p.id
|
trace "sending msgs to peer", toPeer = p.id
|
||||||
let encoded = encodeRpcMsg(m)
|
let encoded = encodeRpcMsg(m)
|
||||||
|
# trigger hooks
|
||||||
|
if p.observers[].len > 0:
|
||||||
|
var mm = m
|
||||||
|
for obs in p.observers[]:
|
||||||
|
obs.onSend(p, mm)
|
||||||
let encodedHex = encoded.buffer.toHex()
|
let encodedHex = encoded.buffer.toHex()
|
||||||
if encoded.buffer.len <= 0:
|
if encoded.buffer.len <= 0:
|
||||||
trace "empty message, skipping", peer = p.id
|
trace "empty message, skipping", peer = p.id
|
||||||
|
|
|
@ -268,6 +268,18 @@ suite "GossipSub":
|
||||||
|
|
||||||
await nodes[1].subscribe("foobar", handler)
|
await nodes[1].subscribe("foobar", handler)
|
||||||
await waitSub(nodes[0], nodes[1], "foobar")
|
await waitSub(nodes[0], nodes[1], "foobar")
|
||||||
|
|
||||||
|
var observed = 0
|
||||||
|
let
|
||||||
|
obs1 = PubSubObserver(onRecv: proc(peer: PubSubPeer; msgs: var RPCMsg) =
|
||||||
|
inc observed
|
||||||
|
)
|
||||||
|
obs2 = PubSubObserver(onSend: proc(peer: PubSubPeer; msgs: var RPCMsg) =
|
||||||
|
inc observed
|
||||||
|
)
|
||||||
|
nodes[1].pubsub.get().addObserver(obs1)
|
||||||
|
nodes[0].pubsub.get().addObserver(obs2)
|
||||||
|
|
||||||
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
|
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
|
||||||
|
|
||||||
var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get())
|
var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get())
|
||||||
|
@ -283,7 +295,7 @@ suite "GossipSub":
|
||||||
await nodes[1].stop()
|
await nodes[1].stop()
|
||||||
await allFuturesThrowing(wait)
|
await allFuturesThrowing(wait)
|
||||||
|
|
||||||
result = true
|
result = observed == 2
|
||||||
|
|
||||||
check:
|
check:
|
||||||
waitFor(runTests()) == true
|
waitFor(runTests()) == true
|
||||||
|
|
Loading…
Reference in New Issue