optionally allow caller to validate messages
This commit is contained in:
parent
0ec8f2fa08
commit
647bb98655
113
floodsub.go
113
floodsub.go
|
@ -54,6 +54,9 @@ type PubSub struct {
|
|||
// topics tracks which topics each of our peers are subscribed to
|
||||
topics map[string]map[peer.ID]struct{}
|
||||
|
||||
// sendMsg handles messages that have been validated
|
||||
sendMsg chan sendReq
|
||||
|
||||
peers map[peer.ID]chan *RPC
|
||||
seenMessages *timecache.TimeCache
|
||||
|
||||
|
@ -91,6 +94,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
|||
getPeers: make(chan *listPeerReq),
|
||||
addSub: make(chan *addSubReq),
|
||||
getTopics: make(chan *topicReq),
|
||||
sendMsg: make(chan sendReq),
|
||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||
topics: make(map[string]map[peer.ID]struct{}),
|
||||
peers: make(map[peer.ID]chan *RPC),
|
||||
|
@ -176,7 +180,19 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
|||
continue
|
||||
}
|
||||
case msg := <-p.publish:
|
||||
p.maybePublishMessage(p.host.ID(), msg.Message)
|
||||
subs := p.getSubscriptions(msg) // call before goroutine!
|
||||
go func() {
|
||||
if p.validate(subs, msg) {
|
||||
p.sendMsg <- sendReq{
|
||||
from: p.host.ID(),
|
||||
msg: msg,
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
case req := <-p.sendMsg:
|
||||
p.maybePublishMessage(req.from, req.msg.Message)
|
||||
|
||||
case <-ctx.Done():
|
||||
log.Info("pubsub processloop shutting down")
|
||||
return
|
||||
|
@ -210,24 +226,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
|
|||
// subscribes to the topic.
|
||||
// Only called from processLoop.
|
||||
func (p *PubSub) handleAddSubscription(req *addSubReq) {
|
||||
subs := p.myTopics[req.topic]
|
||||
sub := req.sub
|
||||
subs := p.myTopics[sub.topic]
|
||||
|
||||
// announce we want this topic
|
||||
if len(subs) == 0 {
|
||||
p.announce(req.topic, true)
|
||||
p.announce(sub.topic, true)
|
||||
}
|
||||
|
||||
// make new if not there
|
||||
if subs == nil {
|
||||
p.myTopics[req.topic] = make(map[*Subscription]struct{})
|
||||
subs = p.myTopics[req.topic]
|
||||
p.myTopics[sub.topic] = make(map[*Subscription]struct{})
|
||||
subs = p.myTopics[sub.topic]
|
||||
}
|
||||
|
||||
sub := &Subscription{
|
||||
ch: make(chan *Message, 32),
|
||||
topic: req.topic,
|
||||
cancelCh: p.cancelCh,
|
||||
}
|
||||
sub.ch = make(chan *Message, 32)
|
||||
sub.cancelCh = p.cancelCh
|
||||
|
||||
p.myTopics[sub.topic][sub] = struct{}{}
|
||||
|
||||
|
@ -314,7 +328,15 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||
continue
|
||||
}
|
||||
|
||||
p.maybePublishMessage(rpc.from, pmsg)
|
||||
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
|
||||
go func() {
|
||||
if p.validate(subs, &Message{pmsg}) {
|
||||
p.sendMsg <- sendReq{
|
||||
from: rpc.from,
|
||||
msg: &*Message{pmsg},
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -324,6 +346,17 @@ func msgID(pmsg *pb.Message) string {
|
|||
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
|
||||
}
|
||||
|
||||
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
|
||||
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
|
||||
for _, sub := range subs {
|
||||
if sub.validate != nil && !sub.validate(msg) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
|
||||
id := msgID(pmsg)
|
||||
if p.seenMessage(id) {
|
||||
|
@ -375,20 +408,47 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// getSubscriptions returns all subscriptions the would receive the given message.
|
||||
func (p *PubSub) getSubscriptions(msg *Message) []*Subscription {
|
||||
var subs []*Subscription
|
||||
|
||||
for _, topic := range msg.GetTopicIDs() {
|
||||
tSubs, ok := p.myTopics[topic]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for sub := range tSubs {
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
}
|
||||
|
||||
return subs
|
||||
}
|
||||
|
||||
type addSubReq struct {
|
||||
topic string
|
||||
resp chan *Subscription
|
||||
sub *Subscription
|
||||
resp chan *Subscription
|
||||
}
|
||||
|
||||
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
|
||||
func WithValidator(validate func(*Message) bool) func(*Subscription) error {
|
||||
return func(sub *Subscription) error {
|
||||
sub.validate = validate
|
||||
return nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Subscribe returns a new Subscription for the given topic
|
||||
func (p *PubSub) Subscribe(topic string) (*Subscription, error) {
|
||||
func (p *PubSub) Subscribe(topic string, opts ...func(*Subscription) error) (*Subscription, error) {
|
||||
td := pb.TopicDescriptor{Name: &topic}
|
||||
|
||||
return p.SubscribeByTopicDescriptor(&td)
|
||||
return p.SubscribeByTopicDescriptor(&td, opts...)
|
||||
}
|
||||
|
||||
// SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor
|
||||
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) {
|
||||
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...func(*Subscription) error) (*Subscription, error) {
|
||||
if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
|
||||
return nil, fmt.Errorf("auth mode not yet supported")
|
||||
}
|
||||
|
@ -397,10 +457,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti
|
|||
return nil, fmt.Errorf("encryption mode not yet supported")
|
||||
}
|
||||
|
||||
sub := &Subscription{
|
||||
topic: td.GetName(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(sub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
out := make(chan *Subscription, 1)
|
||||
p.addSub <- &addSubReq{
|
||||
topic: td.GetName(),
|
||||
resp: out,
|
||||
sub: sub,
|
||||
resp: out,
|
||||
}
|
||||
|
||||
return <-out, nil
|
||||
|
@ -439,6 +510,12 @@ type listPeerReq struct {
|
|||
topic string
|
||||
}
|
||||
|
||||
// sendReq is a request to call maybePublishMessage. It is issued after the subscription verification is done.
|
||||
type sendReq struct {
|
||||
from peer.ID
|
||||
msg *Message
|
||||
}
|
||||
|
||||
// ListPeers returns a list of peers we are connected to.
|
||||
func (p *PubSub) ListPeers(topic string) []peer.ID {
|
||||
out := make(chan []peer.ID)
|
||||
|
|
|
@ -323,14 +323,53 @@ func TestOneToOne(t *testing.T) {
|
|||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
|
||||
ch, err := psubs[1].Subscribe("foobar")
|
||||
sub, err := psubs[1].Subscribe("foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
checkMessageRouting(t, "foobar", psubs, []*Subscription{ch})
|
||||
checkMessageRouting(t, "foobar", psubs, []*Subscription{sub})
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool {
|
||||
return !bytes.Contains(msg.Data, []byte("illegal"))
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
data := make([]byte, 16)
|
||||
rand.Read(data)
|
||||
|
||||
data = append(data, []byte("illegal")...)
|
||||
|
||||
for _, p := range psubs {
|
||||
err := p.Publish(topic, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-sub.ch:
|
||||
t.Log(msg)
|
||||
t.Fatal("expected message validation to filter out the message")
|
||||
case <-time.After(333 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
||||
|
|
|
@ -9,6 +9,7 @@ type Subscription struct {
|
|||
ch chan *Message
|
||||
cancelCh chan<- *Subscription
|
||||
err error
|
||||
validate func(*Message) bool
|
||||
}
|
||||
|
||||
func (sub *Subscription) Topic() string {
|
||||
|
|
Loading…
Reference in New Issue