Add PubSub observer+ hooks (they can modify as well)

This commit is contained in:
Giovanni Petrantoni 2020-04-30 22:22:31 +09:00 committed by Dmitriy Ryajov
parent 268253ea18
commit c889224012
4 changed files with 52 additions and 15 deletions

View File

@ -138,6 +138,7 @@ method unsubscribe*(f: FloodSub,
await f.sendSubs(p, topics.mapIt(it.topic).deduplicate(), false) await f.sendSubs(p, topics.mapIt(it.topic).deduplicate(), false)
method initPubSub*(f: FloodSub) = method initPubSub*(f: FloodSub) =
procCall PubSub(f).initPubSub()
f.peers = initTable[string, PubSubPeer]() f.peers = initTable[string, PubSubPeer]()
f.topics = initTable[string, Topic]() f.topics = initTable[string, Topic]()
f.floodsub = initTable[string, HashSet[string]]() f.floodsub = initTable[string, HashSet[string]]()

View File

@ -16,6 +16,7 @@ import pubsubpeer,
../../peerinfo ../../peerinfo
export PubSubPeer export PubSubPeer
export PubSubObserver
logScope: logScope:
topic = "PubSub" topic = "PubSub"
@ -42,6 +43,7 @@ type
sign*: bool # enable message signing sign*: bool # enable message signing
cleanupLock: AsyncLock cleanupLock: AsyncLock
validators*: Table[string, HashSet[ValidatorHandler]] validators*: Table[string, HashSet[ValidatorHandler]]
observers: ref seq[PubSubObserver] # ref as in smart_ptr
proc sendSubs*(p: PubSub, proc sendSubs*(p: PubSub,
peer: PubSubPeer, peer: PubSubPeer,
@ -72,6 +74,7 @@ method rpcHandler*(p: PubSub,
rpcMsgs: seq[RPCMsg]) {.async, base.} = rpcMsgs: seq[RPCMsg]) {.async, base.} =
## handle rpc messages ## handle rpc messages
trace "processing RPC message", peer = peer.id, msgs = rpcMsgs.len trace "processing RPC message", peer = peer.id, msgs = rpcMsgs.len
for m in rpcMsgs: # for all RPC messages for m in rpcMsgs: # for all RPC messages
trace "processing messages", msg = m.shortLog trace "processing messages", msg = m.shortLog
if m.subscriptions.len > 0: # if there are any subscriptions if m.subscriptions.len > 0: # if there are any subscriptions
@ -104,6 +107,7 @@ proc getPeer(p: PubSub,
p.peers[peer.id] = peer p.peers[peer.id] = peer
peer.refs.inc # increment reference cound peer.refs.inc # increment reference cound
peer.observers = p.observers
result = peer result = peer
method handleConn*(p: PubSub, method handleConn*(p: PubSub,
@ -201,7 +205,7 @@ method publish*(p: PubSub,
method initPubSub*(p: PubSub) {.base.} = method initPubSub*(p: PubSub) {.base.} =
## perform pubsub initializaion ## perform pubsub initializaion
discard p.observers = new(seq[PubSubObserver])
method start*(p: PubSub) {.async, base.} = method start*(p: PubSub) {.async, base.} =
## start pubsub ## start pubsub
@ -253,3 +257,10 @@ proc newPubSub*(P: typedesc[PubSub],
sign: sign, sign: sign,
cleanupLock: newAsyncLock()) cleanupLock: newAsyncLock())
result.initPubSub() result.initPubSub()
proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer
proc removeObserver*(p: PubSub; observer: PubSubObserver) =
let idx = p.observers[].find(observer)
if idx != -1:
p.observers[].del(idx)

View File

@ -23,6 +23,10 @@ logScope:
topic = "PubSubPeer" topic = "PubSubPeer"
type type
PubSubObserver* = ref object
onRecv*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe.}
onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe.}
PubSubPeer* = ref object of RootObj PubSubPeer* = ref object of RootObj
proto: string # the protocol that this peer joined from proto: string # the protocol that this peer joined from
sendConn: Connection sendConn: Connection
@ -33,6 +37,7 @@ type
recvdRpcCache: TimedCache[string] # cache for already received messages recvdRpcCache: TimedCache[string] # cache for already received messages
refs*: int # refcount of the connections this peer is handling refs*: int # refcount of the connections this peer is handling
onConnect: AsyncEvent onConnect: AsyncEvent
observers*: ref seq[PubSubObserver] # ref as in smart_ptr
RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.}
@ -58,8 +63,11 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
trace "message already received, skipping", peer = p.id trace "message already received, skipping", peer = p.id
continue continue
let msg = decodeRpcMsg(data) var msg = decodeRpcMsg(data)
trace "decoded msg from peer", peer = p.id, msg = msg.shortLog trace "decoded msg from peer", peer = p.id, msg = msg.shortLog
# trigger hooks
for obs in p.observers[]:
obs.onRecv(p, msg)
await p.handler(p, @[msg]) await p.handler(p, @[msg])
p.recvdRpcCache.put($hexData.hash) p.recvdRpcCache.put($hexData.hash)
except CatchableError as exc: except CatchableError as exc:
@ -71,9 +79,14 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
try: try:
for m in msgs: for m in msgs.items:
trace "sending msgs to peer", toPeer = p.id trace "sending msgs to peer", toPeer = p.id
let encoded = encodeRpcMsg(m) let encoded = encodeRpcMsg(m)
# trigger hooks
if p.observers[].len > 0:
var mm = m
for obs in p.observers[]:
obs.onSend(p, mm)
let encodedHex = encoded.buffer.toHex() let encodedHex = encoded.buffer.toHex()
if encoded.buffer.len <= 0: if encoded.buffer.len <= 0:
trace "empty message, skipping", peer = p.id trace "empty message, skipping", peer = p.id

View File

@ -268,6 +268,18 @@ suite "GossipSub":
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar") await waitSub(nodes[0], nodes[1], "foobar")
var observed = 0
let
obs1 = PubSubObserver(onRecv: proc(peer: PubSubPeer; msgs: var RPCMsg) =
inc observed
)
obs2 = PubSubObserver(onSend: proc(peer: PubSubPeer; msgs: var RPCMsg) =
inc observed
)
nodes[1].pubsub.get().addObserver(obs1)
nodes[0].pubsub.get().addObserver(obs2)
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get()) var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get())
@ -283,7 +295,7 @@ suite "GossipSub":
await nodes[1].stop() await nodes[1].stop()
await allFuturesThrowing(wait) await allFuturesThrowing(wait)
result = true result = observed == 2
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true