From 68cc57669e1287cbf47f4a675f11cd9d65bfa71a Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Mon, 16 Dec 2019 23:24:03 -0600 Subject: [PATCH] Feat/pubsub validators (#58) * feat: adding validator hooks to pubsub * expose add/remove validators on switch * do less unnecessary copyng --- libp2p/connection.nim | 3 +- libp2p/protocols/identify.nim | 2 +- libp2p/protocols/pubsub/floodsub.nim | 27 +++-- libp2p/protocols/pubsub/gossipsub.nim | 74 ++++++------ libp2p/protocols/pubsub/pubsub.nim | 87 +++++++++++---- libp2p/protocols/pubsub/pubsubpeer.nim | 8 +- libp2p/protocols/pubsub/rpc/message.nim | 12 +- libp2p/protocols/secure/secio.nim | 9 +- libp2p/switch.nim | 18 +++ tests/pubsub/testfloodsub.nim | 119 ++++++++++++++++++-- tests/pubsub/testgossipsub.nim | 142 ++++++++++++++++++++---- 11 files changed, 385 insertions(+), 116 deletions(-) diff --git a/libp2p/connection.nim b/libp2p/connection.nim index dd359613b..5395be965 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -141,4 +141,5 @@ method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcs result = c.observedAddrs proc `$`*(conn: Connection): string = - result = $(conn.peerInfo) + if not isNil(conn.peerInfo): + result = $(conn.peerInfo) diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index 43e10833d..3bed4e0bf 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -139,7 +139,7 @@ proc identify*(p: Identify, if peer != remotePeerInfo.peerId: trace "Peer ids don't match", remote = peer.pretty(), - local = remotePeerInfo.get().id + local = remotePeerInfo.id raise newException(IdentityNoMatchError, "Peer ids don't match") diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index 3abb8d65d..73916e10b 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -46,26 +46,31 @@ method subscribeTopic*(f: FloodSub, # unsubscribe the peer from the topic f.floodsub[topic].excl(peerId) -method handleDisconnect*(f: FloodSub, peer: PubSubPeer) {.async, gcsafe.} = +method handleDisconnect*(f: FloodSub, peer: PubSubPeer) {.async.} = ## handle peer disconnects for t in f.floodsub.keys: f.floodsub[t].excl(peer.id) method rpcHandler*(f: FloodSub, peer: PubSubPeer, - rpcMsgs: seq[RPCMsg]) {.async, gcsafe.} = - trace "processing RPC message", peer = peer.id, msg = rpcMsgs - for m in rpcMsgs: # for all RPC messages - trace "processing message", msg = rpcMsgs - if m.subscriptions.len > 0: # if there are any subscriptions - for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic - f.subscribeTopic(s.topic, s.subscribe, peer.id) + rpcMsgs: seq[RPCMsg]) {.async.} = + await procCall PubSub(f).rpcHandler(peer, rpcMsgs) + for m in rpcMsgs: # for all RPC messages if m.messages.len > 0: # if there are any messages var toSendPeers: HashSet[string] = initHashSet[string]() for msg in m.messages: # for every message if msg.msgId notin f.seen: f.seen.put(msg.msgId) # add the message to the seen cache + + if not msg.verify(peer.peerInfo): + trace "dropping message due to failed signature verification" + continue + + if not (await f.validate(msg)): + trace "dropping message due to failed validation" + continue + for t in msg.topicIDs: # for every topic in the message if t in f.floodsub: toSendPeers.incl(f.floodsub[t]) # get all the peers interested in this topic @@ -79,7 +84,7 @@ method rpcHandler*(f: FloodSub, await f.peers[p].send(@[RPCMsg(messages: m.messages)]) method init(f: FloodSub) = - proc handler(conn: Connection, proto: string) {.async, gcsafe.} = + proc handler(conn: Connection, proto: string) {.async.} = ## main protocol handler that gets triggered on every ## connection for a protocol string ## e.g. ``/floodsub/1.0.0``, etc... @@ -92,7 +97,7 @@ method init(f: FloodSub) = method publish*(f: FloodSub, topic: string, - data: seq[byte]) {.async, gcsafe.} = + data: seq[byte]) {.async.} = await procCall PubSub(f).publish(topic, data) if data.len <= 0 or topic.len <= 0: @@ -110,7 +115,7 @@ method publish*(f: FloodSub, await f.peers[p].send(@[RPCMsg(messages: @[msg])]) method unsubscribe*(f: FloodSub, - topics: seq[TopicPair]) {.async, gcsafe.} = + topics: seq[TopicPair]) {.async.} = await procCall PubSub(f).unsubscribe(topics) for p in f.peers.values: diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 532a3914f..dd64a6d5b 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -73,7 +73,7 @@ proc addInterval(every: Duration, cb: CallbackFunc, return retFuture method init(g: GossipSub) = - proc handler(conn: Connection, proto: string) {.async, gcsafe.} = + proc handler(conn: Connection, proto: string) {.async.} = ## main protocol handler that gets triggered on every ## connection for a protocol string ## e.g. ``/floodsub/1.0.0``, etc... @@ -84,7 +84,7 @@ method init(g: GossipSub) = g.handler = handler g.codec = GossipSubCodec -method handleDisconnect(g: GossipSub, peer: PubSubPeer) {.async, gcsafe.} = +method handleDisconnect(g: GossipSub, peer: PubSubPeer) {.async.} = ## handle peer disconnects await procCall FloodSub(g).handleDisconnect(peer) for t in g.gossipsub.keys: @@ -161,16 +161,10 @@ proc handleIWant(g: GossipSub, peer: PubSubPeer, iwants: seq[ method rpcHandler(g: GossipSub, peer: PubSubPeer, - rpcMsgs: seq[RPCMsg]) {.async, gcsafe.} = + rpcMsgs: seq[RPCMsg]) {.async.} = await procCall PubSub(g).rpcHandler(peer, rpcMsgs) - trace "processing RPC message", peer = peer.id, msg = rpcMsgs - for m in rpcMsgs: # for all RPC messages - trace "processing messages", msg = rpcMsgs - if m.subscriptions.len > 0: # if there are any subscriptions - for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic - g.subscribeTopic(s.topic, s.subscribe, peer.id) - + for m in rpcMsgs: # for all RPC messages if m.messages.len > 0: # if there are any messages var toSendPeers: HashSet[string] = initHashSet[string]() for msg in m.messages: # for every message @@ -181,6 +175,14 @@ method rpcHandler(g: GossipSub, g.seen.put(msg.msgId) # add the message to the seen cache + if not msg.verify(peer.peerInfo): + trace "dropping message due to failed signature verification" + continue + + if not (await g.validate(msg)): + trace "dropping message due to failed validation" + continue + # this shouldn't happen if g.peerInfo.peerId == msg.fromPeerId(): trace "skipping messages from self", msg = msg.msgId @@ -227,10 +229,9 @@ method rpcHandler(g: GossipSub, if respControl.graft.len > 0 or respControl.prune.len > 0 or respControl.ihave.len > 0 or respControl.iwant.len > 0: - await peer.send(@[RPCMsg(control: some(respControl), - messages: messages)]) + await peer.send(@[RPCMsg(control: some(respControl), messages: messages)]) -proc replenishFanout(g: GossipSub, topic: string) {.async, gcsafe.} = +proc replenishFanout(g: GossipSub, topic: string) {.async.} = ## get fanout peers for a topic trace "about to replenish fanout" if topic notin g.fanout: @@ -246,7 +247,7 @@ proc replenishFanout(g: GossipSub, topic: string) {.async, gcsafe.} = trace "fanout replenished with peers", peers = g.fanout[topic].len -proc rebalanceMesh(g: GossipSub, topic: string) {.async, gcsafe.} = +proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = trace "about to rebalance mesh" # create a mesh topic that we're subscribing to if topic notin g.mesh: @@ -288,7 +289,7 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async, gcsafe.} = trace "mesh balanced, got peers", peers = g.mesh[topic].len -proc dropFanoutPeers(g: GossipSub) {.async, gcsafe.} = +proc dropFanoutPeers(g: GossipSub) {.async.} = # drop peers that we haven't published to in # GossipSubFanoutTTL seconds for topic in g.lastFanoutPubSub.keys: @@ -334,7 +335,7 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = result[id] = ControlMessage() result[id].ihave.add(ihave) -proc heartbeat(g: GossipSub) {.async, gcsafe.} = +proc heartbeat(g: GossipSub) {.async.} = trace "running heartbeat" await g.heartbeatLock.acquire() @@ -353,12 +354,12 @@ proc heartbeat(g: GossipSub) {.async, gcsafe.} = method subscribe*(g: GossipSub, topic: string, - handler: TopicHandler) {.async, gcsafe.} = + handler: TopicHandler) {.async.} = await procCall PubSub(g).subscribe(topic, handler) asyncCheck g.rebalanceMesh(topic) method unsubscribe*(g: GossipSub, - topics: seq[TopicPair]) {.async, gcsafe.} = + topics: seq[TopicPair]) {.async.} = await procCall PubSub(g).unsubscribe(topics) for pair in topics: @@ -372,10 +373,11 @@ method unsubscribe*(g: GossipSub, method publish*(g: GossipSub, topic: string, - data: seq[byte]) {.async, gcsafe.} = + data: seq[byte]) {.async.} = await procCall PubSub(g).publish(topic, data) - trace "about to publish message on topic", name = topic, data = data.toHex() + trace "about to publish message on topic", name = topic, + data = data.toHex() if data.len > 0 and topic.len > 0: var peers: HashSet[string] if topic in g.topics: # if we're subscribed to the topic attempt to build a mesh @@ -453,7 +455,7 @@ when isMainModule and not defined(release): let topic = "foobar" gossipSub.mesh[topic] = initHashSet[string]() - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard for i in 0..<15: @@ -480,7 +482,7 @@ when isMainModule and not defined(release): let topic = "foobar" gossipSub.gossipsub[topic] = initHashSet[string]() - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard for i in 0..<15: @@ -505,12 +507,12 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard let topic = "foobar" gossipSub.gossipsub[topic] = initHashSet[string]() - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard for i in 0..<15: @@ -535,13 +537,13 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard let topic = "foobar" gossipSub.fanout[topic] = initHashSet[string]() gossipSub.lastFanoutPubSub[topic] = Moment.fromNow(100.millis) - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard for i in 0..<6: @@ -568,7 +570,7 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard let topic1 = "foobar1" @@ -578,7 +580,7 @@ when isMainModule and not defined(release): gossipSub.lastFanoutPubSub[topic1] = Moment.fromNow(100.millis) gossipSub.lastFanoutPubSub[topic1] = Moment.fromNow(500.millis) - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard for i in 0..<6: @@ -608,10 +610,10 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard let topic = "foobar" @@ -657,10 +659,10 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard let topic = "foobar" @@ -689,10 +691,10 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard let topic = "foobar" @@ -721,10 +723,10 @@ when isMainModule and not defined(release): let gossipSub = newPubSub(TestGossipSub, PeerInfo.init(PrivateKey.random(RSA))) - proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = discard - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = + proc writeHandler(data: seq[byte]) {.async.} = discard let topic = "foobar" diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index c591dff79..e9250795c 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import tables, options, sequtils +import tables, sequtils, sets import chronos, chronicles import pubsubpeer, rpc/messages, @@ -22,8 +22,11 @@ logScope: topic = "PubSub" type - TopicHandler* = proc (topic: string, - data: seq[byte]): Future[void] {.gcsafe.} + TopicHandler* = proc(topic: string, + data: seq[byte]): Future[void] {.gcsafe.} + + ValidatorHandler* = proc(topic: string, + message: Message): Future[bool] {.closure.} TopicPair* = tuple[topic: string, handler: TopicHandler] @@ -37,11 +40,12 @@ type peers*: Table[string, PubSubPeer] # peerid to peer map triggerSelf*: bool # trigger own local handler on publish cleanupLock: AsyncLock + validators*: Table[string, HashSet[ValidatorHandler]] proc sendSubs*(p: PubSub, peer: PubSubPeer, topics: seq[string], - subscribe: bool) {.async, gcsafe.} = + subscribe: bool) {.async.} = ## send subscriptions to remote peer trace "sending subscriptions", peer = peer.id, subscribe = subscribe, @@ -56,13 +60,24 @@ proc sendSubs*(p: PubSub, await peer.send(@[msg]) -method rpcHandler*(p: PubSub, - peer: PubSubPeer, - rpcMsgs: seq[RPCMsg]) {.async, base, gcsafe.} = - ## handle rpc messages +method subscribeTopic*(p: PubSub, + topic: string, + subscribe: bool, + peerId: string) {.base.} = discard -method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.async, base, gcsafe.} = +method rpcHandler*(p: PubSub, + peer: PubSubPeer, + rpcMsgs: seq[RPCMsg]) {.async, base.} = + ## handle rpc messages + trace "processing RPC message", peer = peer.id, msg = rpcMsgs + for m in rpcMsgs: # for all RPC messages + trace "processing messages", msg = rpcMsgs + if m.subscriptions.len > 0: # if there are any subscriptions + for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic + p.subscribeTopic(s.topic, s.subscribe, peer.id) + +method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.async, base.} = ## handle peer disconnects if peer.id in p.peers: p.peers.del(peer.id) @@ -90,7 +105,7 @@ proc getPeer(p: PubSub, peerInfo: PeerInfo, proto: string): PubSubPeer = method handleConn*(p: PubSub, conn: Connection, - proto: string) {.base, async, gcsafe.} = + proto: string) {.base, async.} = ## handle incoming connections ## ## this proc will: @@ -107,7 +122,7 @@ method handleConn*(p: PubSub, await conn.close() return - proc handler(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async, gcsafe.} = + proc handler(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = # call floodsub rpc handler await p.rpcHandler(peer, msgs) @@ -122,7 +137,7 @@ method handleConn*(p: PubSub, await p.cleanUpHelper(peer) method subscribeToPeer*(p: PubSub, - conn: Connection) {.base, async, gcsafe.} = + conn: Connection) {.base, async.} = var peer = p.getPeer(conn.peerInfo, p.codec) trace "setting connection for peer", peerId = conn.peerInfo.id if not peer.isConnected: @@ -137,7 +152,7 @@ method subscribeToPeer*(p: PubSub, asyncCheck p.cleanUpHelper(peer) method unsubscribe*(p: PubSub, - topics: seq[TopicPair]) {.base, async, gcsafe.} = + topics: seq[TopicPair]) {.base, async.} = ## unsubscribe from a list of ``topic`` strings for t in topics: for i, h in p.topics[t.topic].handler: @@ -146,19 +161,13 @@ method unsubscribe*(p: PubSub, method unsubscribe*(p: PubSub, topic: string, - handler: TopicHandler): Future[void] {.base, gcsafe.} = + handler: TopicHandler): Future[void] {.base.} = ## unsubscribe from a ``topic`` string result = p.unsubscribe(@[(topic, handler)]) -method subscribeTopic*(p: PubSub, - topic: string, - subscribe: bool, - peerId: string) {.base, gcsafe.} = - discard - method subscribe*(p: PubSub, topic: string, - handler: TopicHandler) {.base, async, gcsafe.} = + handler: TopicHandler) {.base, async.} = ## subscribe to a topic ## ## ``topic`` - a string topic to subscribe to @@ -178,7 +187,7 @@ method subscribe*(p: PubSub, method publish*(p: PubSub, topic: string, - data: seq[byte]) {.base, async, gcsafe.} = + data: seq[byte]) {.base, async.} = ## publish to a ``topic`` if p.triggerSelf and topic in p.topics: for h in p.topics[topic].handler: @@ -190,14 +199,44 @@ method initPubSub*(p: PubSub) {.base.} = method start*(p: PubSub) {.async, base.} = ## start pubsub - ## start long running/repeating procedures discard method stop*(p: PubSub) {.async, base.} = ## stopt pubsub - ## stop long running/repeating procedures discard +method addValidator*(p: PubSub, + topic: varargs[string], + hook: ValidatorHandler) {.base.} = + for t in topic: + if t notin p.validators: + p.validators[t] = initHashSet[ValidatorHandler]() + + trace "adding validator for topic", topicId = t + p.validators[t].incl(hook) + +method removeValidator*(p: PubSub, + topic: varargs[string], + hook: ValidatorHandler) {.base.} = + for t in topic: + if t in p.validators: + p.validators[t].excl(hook) + +method validate*(p: PubSub, message: Message): Future[bool] {.async, base.} = + var pending: seq[Future[bool]] + trace "about to validate message" + for topic in message.topicIDs: + trace "looking for validators on topic", topicID = topic, + registered = toSeq(p.validators.keys) + if topic in p.validators: + trace "running validators for topic", topicID = topic + # TODO: add timeout to validator + pending.add(p.validators[topic].mapIt(it(topic, message))) + + await allFutures(pending) # await all futures + if pending.allIt(it.read()): # if there are failed + result = true + proc newPubSub*(p: typedesc[PubSub], peerInfo: PeerInfo, triggerSelf: bool = false): p = diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 1a5c440a6..eb8c091d4 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -45,7 +45,7 @@ proc `conn=`*(p: PubSubPeer, conn: Connection) = p.sendConn = conn p.onConnect.fire() -proc handle*(p: PubSubPeer, conn: Connection) {.async, gcsafe.} = +proc handle*(p: PubSubPeer, conn: Connection) {.async.} = trace "handling pubsub rpc", peer = p.id, closed = conn.closed try: while not conn.closed: @@ -66,7 +66,7 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async, gcsafe.} = finally: trace "exiting pubsub peer read loop", peer = p.id -proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async, gcsafe.} = +proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = try: for m in msgs: trace "sending msgs to peer", toPeer = p.id @@ -105,12 +105,12 @@ proc sendMsg*(p: PubSubPeer, data: seq[byte]): Future[void] {.gcsafe.} = p.send(@[RPCMsg(messages: @[newMessage(p.peerInfo, data, topic)])]) -proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async, gcsafe.} = +proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async.} = for topic in topics: trace "sending graft msg to peer", peer = p.id, topicID = topic await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))]) -proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async, gcsafe.} = +proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async.} = for topic in topics: trace "sending prune msg to peer", peer = p.id, topicID = topic await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))]) diff --git a/libp2p/protocols/pubsub/rpc/message.nim b/libp2p/protocols/pubsub/rpc/message.nim index 9bceb9a49..66339504c 100644 --- a/libp2p/protocols/pubsub/rpc/message.nim +++ b/libp2p/protocols/pubsub/rpc/message.nim @@ -27,17 +27,16 @@ proc msgId*(m: Message): string = proc fromPeerId*(m: Message): PeerId = PeerID.init(m.fromPeer) -proc sign*(p: PeerInfo, msg: Message): Message {.gcsafe.} = +proc sign*(msg: Message, p: PeerInfo): Message {.gcsafe.} = var buff = initProtoBuffer() encodeMessage(msg, buff) - let prefix = cast[seq[byte]](PubSubPrefix) if buff.buffer.len > 0: result = msg result.signature = p.privateKey. - sign(prefix & buff.buffer). + sign(cast[seq[byte]](PubSubPrefix) & buff.buffer). getBytes() -proc verify*(p: PeerInfo, m: Message): bool = +proc verify*(m: Message, p: PeerInfo): bool = if m.signature.len > 0 and m.key.len > 0: var msg = m msg.signature = @[] @@ -49,7 +48,8 @@ proc verify*(p: PeerInfo, m: Message): bool = var remote: Signature var key: PublicKey if remote.init(m.signature) and key.init(m.key): - result = remote.verify(buff.buffer, key) + trace "verifying signature", remoteSignature = remote + result = remote.verify(cast[seq[byte]](PubSubPrefix) & buff.buffer, key) proc newMessage*(p: PeerInfo, data: seq[byte], @@ -64,6 +64,6 @@ proc newMessage*(p: PeerInfo, seqno: seqno, topicIDs: @[name]) if sign: - result = p.sign(result) + result = result.sign(p) result.key = key diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index a1f0a332f..55ea08924 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -437,16 +437,15 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe. var stream = newBufferStream(writeHandler) asyncCheck readLoop(sconn, stream) - var secured = newConnection(stream) - secured.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) - result = secured - - secured.closeEvent.wait() + result = newConnection(stream) + result.closeEvent.wait() .addCallback do (udata: pointer): trace "wrapped connection closed, closing upstream" if not isNil(sconn) and not sconn.closed: asyncCheck sconn.close() + result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) + method init(s: Secio) {.gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} = trace "handling connection" diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 358d067c6..995b7fbb3 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -314,6 +314,24 @@ proc publish*(s: Switch, topic: string, data: seq[byte]): Future[void] {.gcsafe. result = s.pubSub.get().publish(topic, data) +proc addValidator*(s: Switch, + topics: varargs[string], + hook: ValidatorHandler) = + # add validator + if s.pubSub.isNone: + raise newNoPubSubException() + + s.pubSub.get().addValidator(topics, hook) + +proc removeValidator*(s: Switch, + topics: varargs[string], + hook: ValidatorHandler) = + # pubslish to pubsub topic + if s.pubSub.isNone: + raise newNoPubSubException() + + s.pubSub.get().removeValidator(topics, hook) + proc newSwitch*(peerInfo: PeerInfo, transports: seq[Transport], identity: Identify, diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 937fc1fff..b6b603d1c 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -10,11 +10,15 @@ import unittest, sequtils, options import chronos import utils, - ../../libp2p/[switch, crypto/crypto] + ../../libp2p/[switch, + crypto/crypto, + protocols/pubsub/pubsub, + protocols/pubsub/rpc/messages, + protocols/pubsub/rpc/message] suite "FloodSub": test "FloodSub basic publish/subscribe A -> B": - proc testBasicPubSub(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = var completionFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -36,10 +40,10 @@ suite "FloodSub": await allFutures(awaiters) check: - waitFor(testBasicPubSub()) == true + waitFor(runTests()) == true test "FloodSub basic publish/subscribe B -> A": - proc testBasicPubSub(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = var completionFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -61,10 +65,107 @@ suite "FloodSub": await allFutures(awaiters) check: - waitFor(testBasicPubSub()) == true + waitFor(runTests()) == true + + test "FloodSub validation should succeed": + proc runTests(): Future[bool] {.async.} = + var handlerFut = newFuture[bool]() + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foobar" + handlerFut.complete(true) + + var nodes = generateNodes(2) + var awaiters: seq[Future[void]] + awaiters.add((await nodes[0].start())) + awaiters.add((await nodes[1].start())) + + await subscribeNodes(nodes) + await nodes[1].subscribe("foobar", handler) + await sleepAsync(1000.millis) + + var validatorFut = newFuture[bool]() + proc validator(topic: string, + message: Message): Future[bool] {.async.} = + check topic == "foobar" + validatorFut.complete(true) + result = true + + nodes[1].addValidator("foobar", validator) + await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) + + await allFutures(handlerFut, handlerFut) + await allFutures(nodes[0].stop(), nodes[1].stop()) + await allFutures(awaiters) + result = true + check: + waitFor(runTests()) == true + + test "FloodSub validation should fail": + proc runTests(): Future[bool] {.async.} = + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check false # if we get here, it should fail + + var nodes = generateNodes(2) + var awaiters: seq[Future[void]] + awaiters.add((await nodes[0].start())) + awaiters.add((await nodes[1].start())) + + await subscribeNodes(nodes) + await nodes[1].subscribe("foobar", handler) + await sleepAsync(100.millis) + + var validatorFut = newFuture[bool]() + proc validator(topic: string, + message: Message): Future[bool] {.async.} = + validatorFut.complete(true) + result = false + + nodes[1].addValidator("foobar", validator) + await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) + await allFutures(nodes[0].stop(), nodes[1].stop()) + await allFutures(awaiters) + result = true + + check: + waitFor(runTests()) == true + + test "FloodSub validation one fails and one succeeds": + proc runTests(): Future[bool] {.async.} = + var handlerFut = newFuture[bool]() + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foo" + handlerFut.complete(true) + + var nodes = generateNodes(2) + var awaiters: seq[Future[void]] + awaiters.add((await nodes[0].start())) + awaiters.add((await nodes[1].start())) + + await subscribeNodes(nodes) + await nodes[1].subscribe("foo", handler) + await nodes[1].subscribe("bar", handler) + await sleepAsync(1000.millis) + + proc validator(topic: string, + message: Message): Future[bool] {.async.} = + if topic == "foo": + result = true + else: + result = false + + nodes[1].addValidator("foo", "bar", validator) + await nodes[0].publish("foo", cast[seq[byte]]("Hello!")) + await nodes[0].publish("bar", cast[seq[byte]]("Hello!")) + + await sleepAsync(100.millis) + await allFutures(nodes[0].stop(), nodes[1].stop()) + await allFutures(awaiters) + result = true + check: + waitFor(runTests()) == true test "FloodSub multiple peers, no self trigger": - proc testBasicFloodSub(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = var passed: int proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -93,10 +194,10 @@ suite "FloodSub": result = passed >= 10 # non deterministic, so at least 2 times check: - waitFor(testBasicFloodSub()) == true + waitFor(runTests()) == true test "FloodSub multiple peers, with self trigger": - proc testBasicFloodSub(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = var passed: int proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -127,4 +228,4 @@ suite "FloodSub": result = passed >= 10 # non deterministic, so at least 20 times check: - waitFor(testBasicFloodSub()) == true + waitFor(runTests()) == true diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index e8be14ee2..958b68d05 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -9,22 +9,126 @@ import unittest, sequtils, options, tables, sets import chronos -import utils, ../../libp2p/[switch, - peer, +import utils, ../../libp2p/[peer, peerinfo, connection, crypto/crypto, stream/bufferstream, protocols/pubsub/pubsub, - protocols/pubsub/gossipsub] + protocols/pubsub/gossipsub, + protocols/pubsub/rpc/messages] + proc createGossipSub(): GossipSub = var peerInfo = PeerInfo.init(PrivateKey.random(RSA)) result = newPubSub(GossipSub, peerInfo) suite "GossipSub": + test "GossipSub validation should succeed": + proc runTests(): Future[bool] {.async.} = + var handlerFut = newFuture[bool]() + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foobar" + handlerFut.complete(true) + + var nodes = generateNodes(2, true) + var awaiters: seq[Future[void]] + awaiters.add((await nodes[0].start())) + awaiters.add((await nodes[1].start())) + + await subscribeNodes(nodes) + await nodes[1].subscribe("foobar", handler) + await sleepAsync(1000.millis) + + var validatorFut = newFuture[bool]() + proc validator(topic: string, + message: Message): + Future[bool] {.async.} = + check topic == "foobar" + validatorFut.complete(true) + result = true + + nodes[1].addValidator("foobar", validator) + await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) + + await allFutures(handlerFut, handlerFut) + await allFutures(nodes[0].stop(), nodes[1].stop()) + await allFutures(awaiters) + result = true + check: + waitFor(runTests()) == true + + test "GossipSub validation should fail": + proc runTests(): Future[bool] {.async.} = + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check false # if we get here, it should fail + + var nodes = generateNodes(2, true) + var awaiters: seq[Future[void]] + awaiters.add((await nodes[0].start())) + awaiters.add((await nodes[1].start())) + + await subscribeNodes(nodes) + await nodes[1].subscribe("foobar", handler) + await sleepAsync(100.millis) + + var validatorFut = newFuture[bool]() + proc validator(topic: string, + message: Message): + Future[bool] {.async.} = + validatorFut.complete(true) + result = false + + nodes[1].addValidator("foobar", validator) + await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) + + await sleepAsync(100.millis) + discard await validatorFut + await allFutures(nodes[0].stop(), nodes[1].stop()) + await allFutures(awaiters) + result = true + + check: + waitFor(runTests()) == true + + test "GossipSub validation one fails and one succeeds": + proc runTests(): Future[bool] {.async.} = + var handlerFut = newFuture[bool]() + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foo" + handlerFut.complete(true) + + var nodes = generateNodes(2, true) + var awaiters: seq[Future[void]] + awaiters.add((await nodes[0].start())) + awaiters.add((await nodes[1].start())) + + await subscribeNodes(nodes) + await nodes[1].subscribe("foo", handler) + await nodes[1].subscribe("bar", handler) + await sleepAsync(1000.millis) + + proc validator(topic: string, + message: Message): + Future[bool] {.async.} = + if topic == "foo": + result = true + else: + result = false + + nodes[1].addValidator("foo", "bar", validator) + await nodes[0].publish("foo", cast[seq[byte]]("Hello!")) + await nodes[0].publish("bar", cast[seq[byte]]("Hello!")) + + await sleepAsync(100.millis) + await allFutures(nodes[0].stop(), nodes[1].stop()) + await allFutures(awaiters) + result = true + check: + waitFor(runTests()) == true + test "should add remote peer topic subscriptions": - proc testRun(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = discard @@ -54,7 +158,7 @@ suite "GossipSub": result = true check: - waitFor(testRun()) == true + waitFor(runTests()) == true test "e2e - should add remote peer topic subscriptions": proc testBasicGossipSub(): Future[bool] {.async.} = @@ -91,7 +195,7 @@ suite "GossipSub": waitFor(testBasicGossipSub()) == true test "should add remote peer topic subscriptions if both peers are subscribed": - proc testRun(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = discard @@ -134,7 +238,7 @@ suite "GossipSub": result = true check: - waitFor(testRun()) == true + waitFor(runTests()) == true test "e2e - should add remote peer topic subscriptions if both peers are subscribed": proc testBasicGossipSub(): Future[bool] {.async.} = @@ -179,7 +283,7 @@ suite "GossipSub": waitFor(testBasicGossipSub()) == true # test "send over fanout A -> B": - # proc testRun(): Future[bool] {.async.} = + # proc runTests(): Future[bool] {.async.} = # var handlerFut = newFuture[bool]() # proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = # check: @@ -216,10 +320,10 @@ suite "GossipSub": # result = await handlerFut # check: - # waitFor(testRun()) == true + # waitFor(runTests()) == true test "e2e - send over fanout A -> B": - proc testRun(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = var passed: bool proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -250,10 +354,10 @@ suite "GossipSub": result = passed check: - waitFor(testRun()) == true + waitFor(runTests()) == true # test "send over mesh A -> B": - # proc testRun(): Future[bool] {.async.} = + # proc runTests(): Future[bool] {.async.} = # var passed: bool # proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = # check: @@ -289,10 +393,10 @@ suite "GossipSub": # result = passed # check: - # waitFor(testRun()) == true + # waitFor(runTests()) == true # test "e2e - send over mesh A -> B": - # proc testRun(): Future[bool] {.async.} = + # proc runTests(): Future[bool] {.async.} = # var passed: bool # proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = # check topic == "foobar" @@ -318,10 +422,10 @@ suite "GossipSub": # result = passed # check: - # waitFor(testRun()) == true + # waitFor(runTests()) == true # test "with multiple peers": - # proc testRun(): Future[bool] {.async.} = + # proc runTests(): Future[bool] {.async.} = # var nodes: seq[GossipSub] # for i in 0..<10: # nodes.add(createGossipSub()) @@ -376,10 +480,10 @@ suite "GossipSub": # result = true # check: - # waitFor(testRun()) == true + # waitFor(runTests()) == true test "e2e - with multiple peers": - proc testRun(): Future[bool] {.async.} = + proc runTests(): Future[bool] {.async.} = var nodes: seq[Switch] = newSeq[Switch]() var awaitters: seq[Future[void]] @@ -419,4 +523,4 @@ suite "GossipSub": result = true check: - waitFor(testRun()) == true + waitFor(runTests()) == true